Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 78 additions & 0 deletions tests/experimental/test_async_grpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,16 @@

import itertools
import queue
from unittest.mock import patch

import numpy as np
import pytest
import torch
from datasets import load_dataset
from transformers import AutoTokenizer

from trl.experimental.async_grpo import AsyncGRPOConfig, AsyncGRPOTrainer
from trl.experimental.async_grpo import async_grpo_trainer as async_grpo_trainer_module
from trl.experimental.async_grpo.async_rollout_worker import RolloutSample

from ..testing_utils import TrlTestCase
Expand Down Expand Up @@ -92,6 +95,81 @@ def send_weights(self, iterator):


class TestAsyncGRPOTrainer(TrlTestCase):
def test_init_defaults_to_auto_dtype(self):
    """Without `model_init_kwargs`, the trainer loads the model with dtype='auto' and no device_map."""
    model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
    dataset = load_dataset("trl-internal-testing/zen", "conversational_prompt_completion", split="train")
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    config = AsyncGRPOConfig(output_dir=self.tmp_dir, num_generations=3, report_to="none")
    # Keep a handle on the real loader so the spy still performs the actual load.
    real_loader = async_grpo_trainer_module.AutoModelForCausalLM.from_pretrained

    patcher = patch.object(
        async_grpo_trainer_module.AutoModelForCausalLM, "from_pretrained", wraps=real_loader
    )
    with patcher as spy:
        AsyncGRPOTrainer(
            model=model_id,
            args=config,
            reward_funcs=dummy_reward_func,
            train_dataset=dataset,
            processing_class=tokenizer,
            rollout_worker=_StubRolloutWorker(tokenizer, dataset, num_generations=3),
        )

    # The last call is the trainer's policy-model load; check its keyword arguments.
    call_kwargs = spy.call_args.kwargs
    assert call_kwargs["dtype"] == "auto"
    assert call_kwargs["device_map"] is None

def test_init_converts_model_init_dtype_string(self):
    """A string `dtype` in `model_init_kwargs` (e.g. 'bfloat16') is resolved to the torch.dtype before loading."""
    model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
    dataset = load_dataset("trl-internal-testing/zen", "conversational_prompt_completion", split="train")
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    config = AsyncGRPOConfig(
        output_dir=self.tmp_dir,
        num_generations=3,
        model_init_kwargs={"dtype": "bfloat16"},
        report_to="none",
    )
    # Keep a handle on the real loader so the spy still performs the actual load.
    real_loader = async_grpo_trainer_module.AutoModelForCausalLM.from_pretrained

    patcher = patch.object(
        async_grpo_trainer_module.AutoModelForCausalLM, "from_pretrained", wraps=real_loader
    )
    with patcher as spy:
        trainer = AsyncGRPOTrainer(
            model=model_id,
            args=config,
            reward_funcs=dummy_reward_func,
            train_dataset=dataset,
            processing_class=tokenizer,
            rollout_worker=_StubRolloutWorker(tokenizer, dataset, num_generations=3),
        )

    call_kwargs = spy.call_args.kwargs
    assert call_kwargs["dtype"] == torch.bfloat16
    assert call_kwargs["device_map"] is None
    # The loaded model's parameters must actually be in the requested dtype.
    assert next(trainer.model.parameters()).dtype == torch.bfloat16

def test_init_rejects_invalid_model_init_dtype(self):
    """A `dtype` string that does not name a torch.dtype must raise ValueError at trainer construction."""
    model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
    dataset = load_dataset("trl-internal-testing/zen", "conversational_prompt_completion", split="train")
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    config = AsyncGRPOConfig(
        output_dir=self.tmp_dir,
        num_generations=3,
        model_init_kwargs={"dtype": "not_a_dtype"},
        report_to="none",
    )

    with pytest.raises(ValueError, match="Invalid `dtype` passed to `AsyncGRPOConfig`"):
        AsyncGRPOTrainer(
            model=model_id,
            args=config,
            reward_funcs=dummy_reward_func,
            train_dataset=dataset,
            processing_class=tokenizer,
            rollout_worker=_StubRolloutWorker(tokenizer, dataset, num_generations=3),
        )

def test_init_minimal(self):
# Test that AsyncGRPOTrainer can be instantiated with only model, reward_model and train_dataset
model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
Expand Down
16 changes: 16 additions & 0 deletions trl/experimental/async_grpo/async_grpo_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

from dataclasses import dataclass, field
from typing import Any

from trl.trainer.base_config import _BaseConfig

Expand Down Expand Up @@ -74,6 +75,12 @@ class AsyncGRPOConfig(_BaseConfig):
weight_sync_steps (`int`, *optional*, defaults to `1`):
Number of training steps between weight synchronizations to the vLLM server.

> Other parameters

model_init_kwargs (`dict[str, Any]`, *optional*):
Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the model from a
string.

> Parameters that control the logging

log_completions (`bool`, *optional*, defaults to `False`):
Expand All @@ -89,6 +96,8 @@ class AsyncGRPOConfig(_BaseConfig):
> - `learning_rate`: Defaults to `1e-6` instead of `5e-5`.
"""

_VALID_DICT_FIELDS = _BaseConfig._VALID_DICT_FIELDS + ["model_init_kwargs"]

# Parameters whose default values are overridden from TrainingArguments
learning_rate: float = field(
default=1e-6,
Expand Down Expand Up @@ -184,6 +193,13 @@ class AsyncGRPOConfig(_BaseConfig):
default=1,
metadata={"help": "Number of training steps between weight synchronizations to the vLLM server."},
)
model_init_kwargs: dict[str, Any] | str | None = field(
default=None,
metadata={
"help": "Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the "
"model from a string."
},
)

# Parameters that control the logging
log_completions: bool = field(
Expand Down
21 changes: 20 additions & 1 deletion trl/experimental/async_grpo/async_grpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,7 +289,26 @@ def __init__(

# Model
# Resolve the policy model's `dtype` from `model_init_kwargs` before calling `from_pretrained`.
model_name = model
model_init_kwargs = {} if self.args.model_init_kwargs is None else self.args.model_init_kwargs.copy()
dtype = model_init_kwargs.get("dtype", "auto")
if isinstance(dtype, torch.dtype) or dtype == "auto" or dtype is None:
    pass  # already a usable value; forwarded to `from_pretrained` unchanged
elif isinstance(dtype, str):
    # Resolve e.g. "bfloat16" -> torch.bfloat16. `getattr` alone is not enough:
    # strings like "tensor" or "cuda" resolve to a function/module without raising
    # AttributeError, so explicitly verify the resolved attribute is a torch.dtype.
    resolved_dtype = getattr(torch, dtype, None)
    if not isinstance(resolved_dtype, torch.dtype):
        raise ValueError(
            "Invalid `dtype` passed to `AsyncGRPOConfig`. Expected either 'auto' or a string "
            f"representing a `torch.dtype` (e.g., 'float32'), but got {dtype}."
        )
    dtype = resolved_dtype
else:
    raise ValueError(
        "Invalid `dtype` passed to `AsyncGRPOConfig`. Expected either 'auto', `None`, or a `torch.dtype`, "
        f"but got {dtype}."
    )
model_init_kwargs["dtype"] = dtype
# Default to no device_map unless the user explicitly provided one.
model_init_kwargs["device_map"] = model_init_kwargs.get("device_map", None)
model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs)

# Processing class
if processing_class is None:
Expand Down