From 8b07d42faa3a88daf8ca4cd082deca5b7c71afb7 Mon Sep 17 00:00:00 2001 From: Kirill Dubovikov Date: Tue, 31 Mar 2026 17:35:30 +0300 Subject: [PATCH] fix(async-grpo): honor model init dtype --- tests/experimental/test_async_grpo_trainer.py | 78 +++++++++++++++++++ .../async_grpo/async_grpo_config.py | 16 ++++ .../async_grpo/async_grpo_trainer.py | 21 ++++- 3 files changed, 114 insertions(+), 1 deletion(-) diff --git a/tests/experimental/test_async_grpo_trainer.py b/tests/experimental/test_async_grpo_trainer.py index cc5413a6213..1ee5d197b47 100644 --- a/tests/experimental/test_async_grpo_trainer.py +++ b/tests/experimental/test_async_grpo_trainer.py @@ -14,13 +14,16 @@ import itertools import queue +from unittest.mock import patch import numpy as np +import pytest import torch from datasets import load_dataset from transformers import AutoTokenizer from trl.experimental.async_grpo import AsyncGRPOConfig, AsyncGRPOTrainer +from trl.experimental.async_grpo import async_grpo_trainer as async_grpo_trainer_module from trl.experimental.async_grpo.async_rollout_worker import RolloutSample from ..testing_utils import TrlTestCase @@ -92,6 +95,81 @@ def send_weights(self, iterator): class TestAsyncGRPOTrainer(TrlTestCase): + def test_init_defaults_to_auto_dtype(self): + model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" + dataset = load_dataset("trl-internal-testing/zen", "conversational_prompt_completion", split="train") + tokenizer = AutoTokenizer.from_pretrained(model_id) + training_args = AsyncGRPOConfig(output_dir=self.tmp_dir, num_generations=3, report_to="none") + original_from_pretrained = async_grpo_trainer_module.AutoModelForCausalLM.from_pretrained + + with patch.object( + async_grpo_trainer_module.AutoModelForCausalLM, + "from_pretrained", + wraps=original_from_pretrained, + ) as mock_from_pretrained: + AsyncGRPOTrainer( + model=model_id, + args=training_args, + reward_funcs=dummy_reward_func, + train_dataset=dataset, + processing_class=tokenizer, + rollout_worker=_StubRolloutWorker(tokenizer, dataset, num_generations=3), + ) + + assert mock_from_pretrained.call_args.kwargs["dtype"] == "auto" + assert mock_from_pretrained.call_args.kwargs["device_map"] is None + + def test_init_converts_model_init_dtype_string(self): + model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" + dataset = load_dataset("trl-internal-testing/zen", "conversational_prompt_completion", split="train") + tokenizer = AutoTokenizer.from_pretrained(model_id) + training_args = AsyncGRPOConfig( + output_dir=self.tmp_dir, + num_generations=3, + model_init_kwargs={"dtype": "bfloat16"}, + report_to="none", + ) + original_from_pretrained = async_grpo_trainer_module.AutoModelForCausalLM.from_pretrained + + with patch.object( + async_grpo_trainer_module.AutoModelForCausalLM, + "from_pretrained", + wraps=original_from_pretrained, + ) as mock_from_pretrained: + trainer = AsyncGRPOTrainer( + model=model_id, + args=training_args, + reward_funcs=dummy_reward_func, + train_dataset=dataset, + processing_class=tokenizer, + rollout_worker=_StubRolloutWorker(tokenizer, dataset, num_generations=3), + ) + + assert mock_from_pretrained.call_args.kwargs["dtype"] == torch.bfloat16 + assert mock_from_pretrained.call_args.kwargs["device_map"] is None + assert next(trainer.model.parameters()).dtype == torch.bfloat16 + + def test_init_rejects_invalid_model_init_dtype(self): + model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" + dataset = load_dataset("trl-internal-testing/zen", "conversational_prompt_completion", split="train") + tokenizer = AutoTokenizer.from_pretrained(model_id) + training_args = AsyncGRPOConfig( + output_dir=self.tmp_dir, + num_generations=3, + model_init_kwargs={"dtype": "not_a_dtype"}, + report_to="none", + ) + + with pytest.raises(ValueError, match="Invalid `dtype` passed to `AsyncGRPOConfig`"): + AsyncGRPOTrainer( + model=model_id, + args=training_args, + reward_funcs=dummy_reward_func, + train_dataset=dataset, + processing_class=tokenizer, + rollout_worker=_StubRolloutWorker(tokenizer, dataset, num_generations=3), + ) + def test_init_minimal(self): # Test that AsyncGRPOTrainer can be instantiated with only model, reward_model and train_dataset model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" diff --git a/trl/experimental/async_grpo/async_grpo_config.py b/trl/experimental/async_grpo/async_grpo_config.py index 2afd760e7fc..b4d33d63ce2 100644 --- a/trl/experimental/async_grpo/async_grpo_config.py +++ b/trl/experimental/async_grpo/async_grpo_config.py @@ -13,6 +13,7 @@ # limitations under the License. from dataclasses import dataclass, field +from typing import Any from trl.trainer.base_config import _BaseConfig @@ -74,6 +75,12 @@ class AsyncGRPOConfig(_BaseConfig): weight_sync_steps (`int`, *optional*, defaults to `1`): Number of training steps between weight synchronizations to the vLLM server. + > Other parameters + + model_init_kwargs (`dict[str, Any]`, *optional*): + Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the model from a + string. + > Parameters that control the logging log_completions (`bool`, *optional*, defaults to `False`): @@ -89,6 +96,8 @@ class AsyncGRPOConfig(_BaseConfig): > - `learning_rate`: Defaults to `1e-6` instead of `5e-5`. """ + _VALID_DICT_FIELDS = _BaseConfig._VALID_DICT_FIELDS + ["model_init_kwargs"] + # Parameters whose default values are overridden from TrainingArguments learning_rate: float = field( default=1e-6, @@ -184,6 +193,13 @@ class AsyncGRPOConfig(_BaseConfig): default=1, metadata={"help": "Number of training steps between weight synchronizations to the vLLM server."}, ) + model_init_kwargs: dict[str, Any] | str | None = field( + default=None, + metadata={ + "help": "Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the " + "model from a string." + }, + ) # Parameters that control the logging log_completions: bool = field( diff --git a/trl/experimental/async_grpo/async_grpo_trainer.py b/trl/experimental/async_grpo/async_grpo_trainer.py index aca72c73596..ab31d8b6691 100644 --- a/trl/experimental/async_grpo/async_grpo_trainer.py +++ b/trl/experimental/async_grpo/async_grpo_trainer.py @@ -289,7 +289,26 @@ def __init__( # Model model_name = model - model = AutoModelForCausalLM.from_pretrained(model, device_map=None, dtype=torch.float32) + model_init_kwargs = {} if self.args.model_init_kwargs is None else self.args.model_init_kwargs.copy() + dtype = model_init_kwargs.get("dtype", "auto") + if isinstance(dtype, torch.dtype) or dtype == "auto" or dtype is None: + pass + elif isinstance(dtype, str): + try: + dtype = getattr(torch, dtype) + except AttributeError as exc: + raise ValueError( + "Invalid `dtype` passed to `AsyncGRPOConfig`. Expected either 'auto' or a string " + f"representing a `torch.dtype` (e.g., 'float32'), but got {dtype}." + ) from exc + else: + raise ValueError( + "Invalid `dtype` passed to `AsyncGRPOConfig`. Expected either 'auto', `None`, or a `torch.dtype`, " + f"but got {dtype}." + ) + model_init_kwargs["dtype"] = dtype + model_init_kwargs["device_map"] = model_init_kwargs.get("device_map", None) + model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs) # Processing class if processing_class is None: