diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml index 624ba22b91b..054a6c286b0 100644 --- a/docs/source/_toctree.yml +++ b/docs/source/_toctree.yml @@ -39,6 +39,8 @@ title: Distributing Training - local: use_model title: Using Trained Models + - local: per_sample_tools + title: Per-Sample Tool Filtering title: How-to guides - sections: - local: deepspeed_integration diff --git a/docs/source/per_sample_tools.md b/docs/source/per_sample_tools.md new file mode 100644 index 00000000000..00a6bb9c892 --- /dev/null +++ b/docs/source/per_sample_tools.md @@ -0,0 +1,108 @@ +# Per-Sample Tool Filtering in GRPOTrainer + +## Motivation + +In many agentic settings, different training samples require different subsets of tools. +For example, a math question might only need a `calculator`, while a translation task +might only need a `translator`. Exposing all tools on every sample can confuse the model +and dilute the training signal. + +`GRPOTrainer` automatically detects a `tools` column in your dataset and uses it to restrict +**which tools are available per sample**, drawn from the global `tools` pool. + +## How It Works + +1. **Global tool pool** — You pass the full set of tools to the trainer via `tools=[...]` as before. +2. **Per-sample tool column** — Your dataset includes a `"tools"` column containing a list of tool + **names** (strings matching `tool.__name__`) allowed for each sample. +3. **Automatic filtering** — For each rollout, only the specified tools appear in the model's + system prompt (chat template) and are available for execution. If the column value is `None` + for a sample, all tools are used as a fallback. + +## Example + +```python +from datasets import Dataset +from trl import GRPOTrainer, GRPOConfig + + +# Define tool functions +def calculator(number_a: float, operation: str, number_b: float) -> str: + """Perform a basic arithmetic operation on two numbers. + + Args: + number_a: The first operand. + operation: The operation to perform. 
One of '+', '-', '*', '/'. + number_b: The second operand. + + Returns: + The result of the operation as a string. + + Raises: + ValueError: If the operation is not supported or division by zero is attempted (non-numeric operands raise `TypeError` instead). + """ + try: + number_a = float(number_a) + except (TypeError, ValueError): + raise TypeError(f"number_a must be convertible to a number, got {type(number_a).__name__!r}") + try: + number_b = float(number_b) + except (TypeError, ValueError): + raise TypeError(f"number_b must be convertible to a number, got {type(number_b).__name__!r}") + if operation == "+": + return str(number_a + number_b) + elif operation == "-": + return str(number_a - number_b) + elif operation == "*": + return str(number_a * number_b) + elif operation == "/": + if number_b == 0: + raise ValueError("Division by zero is not allowed.") + return str(number_a / number_b) + else: + raise ValueError(f"Unsupported operation '{operation}'. Use one of: +, -, *, /") + + +def translator(text: str, target_language: str) -> str: + """Translate text to a target language. + + Args: + text: The text to translate. + target_language: ISO language code, e.g. 'fr', 'es', 'de'. + + Returns: + The translated text. 
+ """ + # Placeholder — in practice, call a translation API + return f"[{target_language}] {text}" + + +# Build dataset with per-sample tool column +dataset = Dataset.from_dict({ + "prompt": [ + [{"role": "user", "content": "What is 123 * 456?"}], + [{"role": "user", "content": "Translate 'good morning' to French."}], + [{"role": "user", "content": "Compute 2 * 10 and translate the result to Spanish."}], + ], + "tools": [ + ["calculator"], # only calculator available + ["translator"], # only translator available + ["calculator", "translator"], # both available + ], +}) + +# The trainer automatically detects the "tools" column and applies per-sample filtering +trainer = GRPOTrainer( + model="Qwen/Qwen2.5-0.5B-Instruct", + reward_funcs=my_reward, # placeholder: define your own reward function(s) before running this example + tools=[calculator, translator], + train_dataset=dataset, +) + +trainer.train() +``` + +## Backward Compatibility + +When the dataset has no `tools` column, behavior is identical to the existing API — all +tools in the `tools` list are used for every sample. diff --git a/tests/test_grpo_tools_column.py b/tests/test_grpo_tools_column.py new file mode 100644 index 00000000000..651cbe84cb6 --- /dev/null +++ b/tests/test_grpo_tools_column.py @@ -0,0 +1,660 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Tests for the per-sample tool filtering feature in GRPOTrainer. 
+ +When a dataset contains a `tools` column, GRPOTrainer automatically uses it to restrict +which tools each sample can call during rollout. +""" + +import os +from unittest.mock import patch + +import pytest +import torch +import transformers +from datasets import Dataset, load_dataset +from packaging.version import Version + +from trl import GRPOConfig, GRPOTrainer + +from .testing_utils import TrlTestCase, require_jmespath + + +# ────────────────────────────────────────────────────────────────────── +# Tool definitions +# ────────────────────────────────────────────────────────────────────── + + +def multiply_tool(a: int, b: int) -> int: + """ + Multiplies two integers. + + Args: + a: The first integer. + b: The second integer. + + Returns: + The product of the two integers. + """ + return a * b + + +def add_tool(a: int, b: int) -> int: + """ + Adds two integers. + + Args: + a: The first integer. + b: The second integer. + + Returns: + The sum of the two integers. + """ + return a + b + + +async def async_add_tool(a: int, b: int) -> int: + """ + Asynchronously adds two integers. + + Args: + a: The first integer. + b: The second integer. + + Returns: + The sum of the two integers. 
+ """ + return a + b + + +# ────────────────────────────────────────────────────────────────────── +# Unit-level tests (no model loading, fast) +# ────────────────────────────────────────────────────────────────────── + + +class TestToolsColumnValidation(TrlTestCase): + """Test that per-sample tool filtering via the `tools` dataset column works correctly.""" + + def _make_conversational_dataset(self, tool_names_per_sample): + """Helper: create a minimal conversational dataset with a tools column.""" + prompts = [[{"role": "user", "content": f"Question {i}"}] for i in range(len(tool_names_per_sample))] + return Dataset.from_dict({"prompt": prompts, "tools": tool_names_per_sample}) + + @pytest.mark.xfail( + condition=Version(transformers.__version__) < Version("5.0.0"), + reason="Tool parsing is not supported in transformers versions below 5.0.0", + strict=True, + ) + @require_jmespath + def test_valid_tools_column_passes_validation(self): + """A dataset with a tools column and matching tool pool should init without errors.""" + dataset = self._make_conversational_dataset([["multiply_tool"], ["add_tool"], ["multiply_tool", "add_tool"]]) + trainer = GRPOTrainer( + model="trl-internal-testing/tiny-Qwen3MoeForCausalLM", + reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5", + args=GRPOConfig( + output_dir=self.tmp_dir, + report_to="none", + per_device_train_batch_size=3, + num_generations=3, + ), + train_dataset=dataset, + tools=[multiply_tool, add_tool], + ) + assert {t.__name__ for t in trainer.tools} == {"multiply_tool", "add_tool"} + + @pytest.mark.xfail( + condition=Version(transformers.__version__) < Version("5.0.0"), + reason="Tool parsing is not supported in transformers versions below 5.0.0", + strict=True, + ) + @require_jmespath + def test_no_tools_column_backward_compat(self): + """When the dataset has no tools column, trainer behaves as before (all tools used).""" + dataset = load_dataset("trl-internal-testing/zen", 
"conversational_prompt_only", split="train") + trainer = GRPOTrainer( + model="trl-internal-testing/tiny-Qwen3MoeForCausalLM", + reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5", + args=GRPOConfig( + output_dir=self.tmp_dir, + report_to="none", + per_device_train_batch_size=3, + num_generations=3, + ), + train_dataset=dataset, + tools=[multiply_tool], + ) + assert {t.__name__ for t in trainer.tools} == {"multiply_tool"} + + @pytest.mark.xfail( + condition=Version(transformers.__version__) < Version("5.0.0"), + reason="Tool parsing is not supported in transformers versions below 5.0.0", + strict=True, + ) + @require_jmespath + def test_tools_accessible_per_sample(self): + """Each per-sample tool dict should contain the correct callables for its sample.""" + dataset = self._make_conversational_dataset([["multiply_tool", "add_tool"]]) + trainer = GRPOTrainer( + model="trl-internal-testing/tiny-Qwen3MoeForCausalLM", + reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5", + args=GRPOConfig( + output_dir=self.tmp_dir, + report_to="none", + per_device_train_batch_size=3, + num_generations=3, + ), + train_dataset=dataset, + tools=[multiply_tool, add_tool], + ) + assert trainer._sync_tool_dicts[0]["multiply_tool"] is multiply_tool + assert trainer._sync_tool_dicts[0]["add_tool"] is add_tool + + @pytest.mark.xfail( + condition=Version(transformers.__version__) < Version("5.0.0"), + reason="Tool parsing is not supported in transformers versions below 5.0.0", + strict=True, + ) + @require_jmespath + def test_signature_columns_include_tools(self): + """The `tools` column should always be included in _signature_columns.""" + dataset = self._make_conversational_dataset([["multiply_tool"]]) + trainer = GRPOTrainer( + model="trl-internal-testing/tiny-Qwen3MoeForCausalLM", + reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5", + args=GRPOConfig( + output_dir=self.tmp_dir, + report_to="none", + 
per_device_train_batch_size=3, + num_generations=3, + ), + train_dataset=dataset, + tools=[multiply_tool], + ) + trainer._set_signature_columns_if_needed() + assert "tools" in trainer._signature_columns + + +# ────────────────────────────────────────────────────────────────────── +# Integration tests (model loading + training with fake_generate) +# ────────────────────────────────────────────────────────────────────── + + +class TestToolsColumnTraining(TrlTestCase): + """End-to-end tests for training with per-sample tool filtering.""" + + @pytest.mark.xfail( + condition=Version(transformers.__version__) < Version("5.0.0"), + reason="Tool parsing is not supported in transformers versions below 5.0.0", + strict=True, + ) + @require_jmespath + def test_training_with_tools_column(self): + """Train with a `tools` dataset column and verify per-sample filtering works end-to-end. + + We create a 3-sample dataset where: + - Sample 0: only multiply_tool available → model calls multiply_tool (succeeds) + - Sample 1: only multiply_tool available → model calls multiply_tool (fails, wrong arg) + - Sample 2: only multiply_tool available → model returns plain text (no tool call) + """ + dataset = Dataset.from_dict( + { + "prompt": [ + [{"role": "user", "content": "What is 3 times 4?"}], + [{"role": "user", "content": "Multiply 5 and 6."}], + [{"role": "user", "content": "Tell me a joke."}], + ], + "tools": [ + ["multiply_tool"], + ["multiply_tool"], + ["multiply_tool"], + ], + } + ) + + training_args = GRPOConfig( + output_dir=self.tmp_dir, + learning_rate=0.1, + per_device_train_batch_size=3, + num_generations=3, + max_completion_length=128, + report_to="none", + ) + trainer = GRPOTrainer( + model="trl-internal-testing/tiny-Qwen3MoeForCausalLM", + reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5", + args=training_args, + train_dataset=dataset, + tools=[multiply_tool, add_tool], # global pool has both; dataset column restricts per sample + ) + + 
previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + def fake_generate(input_ids, **kwargs): + if input_ids.shape[0] == 3: # first call + # fmt: off + completion_ids = torch.tensor( + [ + # '\n{"name": "multiply_tool", "arguments": {"a": 3, "b": 4}}\n<|im_end|>' + [151657, 198, 4913, 606, 788, 330, 64648, 22785, 497, 330, 16370, 788, 5212, 64, 788, 220, 18, 11, 330, 65, 788, 220, 19, 11248, 151658, 151645], + # invalid tool call: wrong argument name "c" instead of "b" + # '\n{"name": "multiply_tool", "arguments": {"a": 3, "c": 4}}\n<|im_end|>' + [151657, 198, 4913, 606, 788, 330, 64648, 22785, 497, 330, 16370, 788, 5212, 64, 788, 220, 18, 11, 330, 66, 788, 220, 19, 11248, 151658, 151645], + # "I don't know any tool<|im_end|>" + [40, 1513, 944, 1414, 894, 5392, 151645, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643], + ], + device=input_ids.device, + ) + # fmt: on + else: # second call: only 2 samples had tool calls + completion_ids = torch.tensor( + [ + # 'Done!<|im_end|>' + [17453, 0, 151645], + # 'Done!<|im_end|>' + [17453, 0, 151645], + ], + device=input_ids.device, + ) + return torch.cat([input_ids, completion_ids], dim=-1) + + with patch.object(trainer.model, "generate", side_effect=fake_generate): + trainer.train() + + assert trainer.state.log_history[-1]["train_loss"] is not None + assert trainer.state.log_history[-1]["tools/call_frequency"] == pytest.approx(2 / 3) + assert trainer.state.log_history[-1]["tools/failure_frequency"] == pytest.approx(1 / 2) + + # Check that the params have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + assert not torch.equal(param, new_param), f"Parameter {n} has not changed." 
+ + @pytest.mark.xfail( + condition=Version(transformers.__version__) < Version("5.0.0"), + reason="Tool parsing is not supported in transformers versions below 5.0.0", + strict=True, + ) + @require_jmespath + def test_training_with_tools_column_subset(self): + """Verify that the model only sees the subset of tools declared per sample. + + We have two tools (multiply_tool, add_tool) but each sample only allows one. + Because the model's fake output calls multiply_tool, the sample that only allows add_tool + should fail the tool call (tool not found) — proving the filtering works. + """ + dataset = Dataset.from_dict( + { + "prompt": [ + [{"role": "user", "content": "What is 3 times 4?"}], + [{"role": "user", "content": "Add 1 and 2."}], + [{"role": "user", "content": "Tell me a joke."}], + ], + "tools": [ + ["multiply_tool"], # sample 0: only multiply allowed + ["add_tool"], # sample 1: only add allowed → multiply_tool call will fail ("not found") + ["multiply_tool", "add_tool"], # sample 2: both allowed, but model won't call any + ], + } + ) + + training_args = GRPOConfig( + output_dir=self.tmp_dir, + learning_rate=0.1, + per_device_train_batch_size=3, + num_generations=3, + max_completion_length=128, + report_to="none", + ) + trainer = GRPOTrainer( + model="trl-internal-testing/tiny-Qwen3MoeForCausalLM", + reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5", + args=training_args, + train_dataset=dataset, + tools=[multiply_tool, add_tool], + ) + + def fake_generate(input_ids, **kwargs): + if input_ids.shape[0] == 3: # first call + # fmt: off + completion_ids = torch.tensor( + [ + # Sample 0: calls multiply_tool → should succeed (multiply_tool is allowed) + # '\n{"name": "multiply_tool", "arguments": {"a": 3, "b": 4}}\n<|im_end|>' + [151657, 198, 4913, 606, 788, 330, 64648, 22785, 497, 330, 16370, 788, 5212, 64, 788, 220, 18, 11, 330, 65, 788, 220, 19, 11248, 151658, 151645], + # Sample 1: calls multiply_tool → should FAIL (only add_tool is 
allowed) + # '\n{"name": "multiply_tool", "arguments": {"a": 3, "b": 4}}\n<|im_end|>' + [151657, 198, 4913, 606, 788, 330, 64648, 22785, 497, 330, 16370, 788, 5212, 64, 788, 220, 18, 11, 330, 65, 788, 220, 19, 11248, 151658, 151645], + # Sample 2: plain text, no tool call + # "I don't know any tool<|im_end|>" + [40, 1513, 944, 1414, 894, 5392, 151645, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643], + ], + device=input_ids.device, + ) + # fmt: on + else: # second call: 2 samples had tool calls + completion_ids = torch.tensor( + [ + [17453, 0, 151645], # 'Done!<|im_end|>' + [17453, 0, 151645], # 'Done!<|im_end|>' + ], + device=input_ids.device, + ) + return torch.cat([input_ids, completion_ids], dim=-1) + + with patch.object(trainer.model, "generate", side_effect=fake_generate): + trainer.train() + + assert trainer.state.log_history[-1]["train_loss"] is not None + # 2 out of 3 samples per batch made tool calls + assert trainer.state.log_history[-1]["tools/call_frequency"] == pytest.approx(2 / 3) + # At least some calls should fail because sample 1 only allows add_tool but model calls multiply_tool. + # The exact failure rate depends on batch composition across epochs/steps. 
+ assert trainer.state.log_history[-1]["tools/failure_frequency"] > 0 + + @pytest.mark.xfail( + condition=Version(transformers.__version__) < Version("5.0.0"), + reason="Tool parsing is not supported in transformers versions below 5.0.0", + strict=True, + ) + @require_jmespath + def test_training_without_tools_column_backward_compat(self): + """Training with tools but without a `tools` dataset column should work exactly as before.""" + dataset = load_dataset("trl-internal-testing/zen", "conversational_prompt_only", split="train") + + training_args = GRPOConfig( + output_dir=self.tmp_dir, + learning_rate=0.1, + per_device_train_batch_size=3, + num_generations=3, + max_completion_length=128, + report_to="none", + ) + trainer = GRPOTrainer( + model="trl-internal-testing/tiny-Qwen3MoeForCausalLM", + reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5", + args=training_args, + train_dataset=dataset, + tools=[multiply_tool], + ) + + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + def fake_generate(input_ids, **kwargs): + if input_ids.shape[0] == 3: + # fmt: off + completion_ids = torch.tensor( + [ + # '\n{"name": "multiply_tool", "arguments": {"a": 3, "b": 4}}\n<|im_end|>' + [151657, 198, 4913, 606, 788, 330, 64648, 22785, 497, 330, 16370, 788, 5212, 64, 788, 220, 18, 11, 330, 65, 788, 220, 19, 11248, 151658, 151645], + # invalid tool call: wrong argument name + [151657, 198, 4913, 606, 788, 330, 64648, 22785, 497, 330, 16370, 788, 5212, 64, 788, 220, 18, 11, 330, 66, 788, 220, 19, 11248, 151658, 151645], + # "I don't know any tool<|im_end|>" + [40, 1513, 944, 1414, 894, 5392, 151645, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643], + ], + device=input_ids.device, + ) + # fmt: on + else: + completion_ids = torch.tensor( + [[17453, 0, 151645], [17453, 0, 151645]], + device=input_ids.device, + ) + 
return torch.cat([input_ids, completion_ids], dim=-1) + + with patch.object(trainer.model, "generate", side_effect=fake_generate): + trainer.train() + + assert trainer.state.log_history[-1]["train_loss"] is not None + assert trainer.state.log_history[-1]["tools/call_frequency"] == pytest.approx(2 / 3) + assert trainer.state.log_history[-1]["tools/failure_frequency"] == pytest.approx(1 / 2) + + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + assert not torch.equal(param, new_param), f"Parameter {n} has not changed." + + @pytest.mark.xfail( + condition=Version(transformers.__version__) < Version("5.0.0"), + reason="Tool parsing is not supported in transformers versions below 5.0.0", + strict=True, + ) + @require_jmespath + def test_training_with_none_in_tools_column_falls_back(self): + """When a sample's tools column is None, fall back to the full global tools list.""" + dataset = Dataset.from_dict( + { + "prompt": [ + [{"role": "user", "content": "What is 3 times 4?"}], + [{"role": "user", "content": "What is 5 plus 6?"}], + [{"role": "user", "content": "Tell me a joke."}], + ], + "tools": [ + ["multiply_tool"], # only multiply + None, # fallback to all tools (multiply + add) + ["multiply_tool", "add_tool"], # both tools (fake_generate calls multiply_tool for all batches) + ], + } + ) + + training_args = GRPOConfig( + output_dir=self.tmp_dir, + learning_rate=0.1, + per_device_train_batch_size=3, + num_generations=3, + max_completion_length=128, + report_to="none", + ) + trainer = GRPOTrainer( + model="trl-internal-testing/tiny-Qwen3MoeForCausalLM", + reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5", + args=training_args, + train_dataset=dataset, + tools=[multiply_tool, add_tool], + ) + + def fake_generate(input_ids, **kwargs): + if input_ids.shape[0] == 3: + # fmt: off + completion_ids = torch.tensor( + [ + # Sample 0: calls multiply_tool → ok (allowed) + [151657, 198, 4913, 606, 788, 330, 
64648, 22785, 497, 330, 16370, 788, 5212, 64, 788, 220, 18, 11, 330, 65, 788, 220, 19, 11248, 151658, 151645], + # Sample 1: calls multiply_tool → ok (None fallback = all tools) + [151657, 198, 4913, 606, 788, 330, 64648, 22785, 497, 330, 16370, 788, 5212, 64, 788, 220, 18, 11, 330, 65, 788, 220, 19, 11248, 151658, 151645], + # Sample 2: plain text + [40, 1513, 944, 1414, 894, 5392, 151645, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643], + ], + device=input_ids.device, + ) + # fmt: on + else: + completion_ids = torch.tensor( + [[17453, 0, 151645], [17453, 0, 151645]], + device=input_ids.device, + ) + return torch.cat([input_ids, completion_ids], dim=-1) + + with patch.object(trainer.model, "generate", side_effect=fake_generate): + trainer.train() + + # Both tool calls should succeed (sample 0 has multiply allowed, sample 1 has all tools via None fallback) + assert trainer.state.log_history[-1]["train_loss"] is not None + assert trainer.state.log_history[-1]["tools/call_frequency"] == pytest.approx(2 / 3) + # No failures: both samples' tool calls are to multiply_tool which is in their allowed set + assert trainer.state.log_history[-1]["tools/failure_frequency"] == pytest.approx(0.0) + + @pytest.mark.xfail( + condition=Version(transformers.__version__) < Version("5.0.0"), + reason="Tool parsing is not supported in transformers versions below 5.0.0", + strict=True, + ) + @require_jmespath + def test_training_with_async_tool_and_tools_column(self): + """Verify that async tools also work with per-sample filtering.""" + dataset = Dataset.from_dict( + { + "prompt": [ + [{"role": "user", "content": "What is 3 times 4?"}], + [{"role": "user", "content": "What is 5 plus 6?"}], + [{"role": "user", "content": "Tell me a joke."}], + ], + "tools": [ + ["multiply_tool"], + ["multiply_tool"], + ["multiply_tool"], + ], + } + ) + + training_args = GRPOConfig( + 
output_dir=self.tmp_dir, + learning_rate=0.1, + per_device_train_batch_size=3, + num_generations=3, + max_completion_length=128, + report_to="none", + ) + trainer = GRPOTrainer( + model="trl-internal-testing/tiny-Qwen3MoeForCausalLM", + reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5", + args=training_args, + train_dataset=dataset, + tools=[multiply_tool, async_add_tool], # mix of sync and async in global pool + ) + + def fake_generate(input_ids, **kwargs): + if input_ids.shape[0] == 3: + # fmt: off + completion_ids = torch.tensor( + [ + [151657, 198, 4913, 606, 788, 330, 64648, 22785, 497, 330, 16370, 788, 5212, 64, 788, 220, 18, 11, 330, 65, 788, 220, 19, 11248, 151658, 151645], + [151657, 198, 4913, 606, 788, 330, 64648, 22785, 497, 330, 16370, 788, 5212, 64, 788, 220, 18, 11, 330, 66, 788, 220, 19, 11248, 151658, 151645], + [40, 1513, 944, 1414, 894, 5392, 151645, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643], + ], + device=input_ids.device, + ) + # fmt: on + else: + completion_ids = torch.tensor( + [[17453, 0, 151645], [17453, 0, 151645]], + device=input_ids.device, + ) + return torch.cat([input_ids, completion_ids], dim=-1) + + with patch.object(trainer.model, "generate", side_effect=fake_generate): + trainer.train() + + assert trainer.state.log_history[-1]["train_loss"] is not None + + @pytest.mark.xfail( + condition=Version(transformers.__version__) < Version("5.2.0"), + reason="Environment factory support is not available in transformers versions below 5.2.0", + strict=True, + ) + @require_jmespath + @patch.dict(os.environ, {"TRL_EXPERIMENTAL_SILENCE": "1"}) + def test_training_with_tools_column_and_environment_factory(self): + """Verify training completes when tools column is combined with environment_factory. 
+ + Note: per-sample tool filtering resolves callables from self.tools, which only holds + environment 0's bound methods. All tool calls therefore dispatch to environment 0. + This is a known limitation documented in _generate_and_score_completions. + """ + dataset = Dataset.from_dict( + { + "prompt": [ + [{"role": "user", "content": "Increment by 1."}], + [{"role": "user", "content": "Increment by 1."}], + [{"role": "user", "content": "Tell me a joke."}], + ], + "tools": [ + ["increment"], + ["increment"], + ["increment"], + ], + } + ) + + class DummyEnvironment: + def reset(self, **kwargs): + self._counter = 0 + + def increment(self, step: int) -> int: + """ + Increment the internal counter. + + Args: + step: Value to add to the counter. + + Returns: + The updated counter value. + """ + self._counter += step + return self._counter + + training_args = GRPOConfig( + output_dir=self.tmp_dir, + learning_rate=0.1, + per_device_train_batch_size=3, + num_generations=3, + max_completion_length=128, + report_to="none", + ) + + trainer = GRPOTrainer( + model="trl-internal-testing/tiny-Qwen3MoeForCausalLM", + reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5", + args=training_args, + train_dataset=dataset, + environment_factory=DummyEnvironment, + ) + + def fake_generate(input_ids, **kwargs): + if input_ids.shape[0] == 3: # first call + # fmt: off + completion_ids = torch.tensor( + [ + # Sample 0: '\n{"name": "increment", "arguments": {"step": 1}}\n<|im_end|>' + [151657, 198, 4913, 606, 788, 330, 35744, 497, 330, 16370, 788, 5212, 9520, 788, 220, 16, 11248, 151658, 151645, 151643], + # Sample 1: '\n{"name": "increment", "arguments": {"step": 1}}\n<|im_end|>' + [151657, 198, 4913, 606, 788, 330, 35744, 497, 330, 16370, 788, 5212, 9520, 788, 220, 16, 11248, 151658, 151645, 151643], + # Sample 2: "I won't increment<|im_end|>" + [40, 2765, 944, 16252, 151645, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 
151643, 151643, 151643], + ], + device=input_ids.device, + ) + # fmt: on + else: # second call: 2 samples had tool calls + completion_ids = torch.tensor( + [ + [17453, 0, 151645], # 'Done!<|im_end|>' + [17453, 0, 151645], # 'Done!<|im_end|>' + ], + device=input_ids.device, + ) + return torch.cat([input_ids, completion_ids], dim=-1) + + with patch.object(trainer.model, "generate", side_effect=fake_generate): + trainer.train() + + assert trainer.state.log_history[-1]["train_loss"] is not None + assert trainer.state.log_history[-1]["tools/call_frequency"] == pytest.approx(2 / 3) + assert trainer.state.log_history[-1]["tools/failure_frequency"] == pytest.approx(0.0) + # All tool calls resolve to environment 0's bound methods (known limitation): + # per_sample_tools is built from self.tools which only holds environment 0's callables. + assert trainer.environments[0]._counter == 2 + assert trainer.environments[1]._counter == 0 + assert trainer.environments[2]._counter == 0 diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index c5eed094192..30297e6c239 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -226,7 +226,9 @@ class GRPOTrainer(_BaseTrainer): Google-style docstring describing its purpose, arguments, and return value. For more details, see: https://huggingface.co/docs/transformers/en/chat_extras#passing-tools. The model uses the function's name, type hints, and docstring to determine how to call it. Ensure that the model's chat template supports tool - use and that it has been fine-tuned for tool calling. + use and that it has been fine-tuned for tool calling. If the dataset contains a `tools` column with a + list of tool names per sample, only those tools are exposed during that sample's rollout. If the column + value is missing or `None` for a sample, the full `tools` list is used. rollout_func (`RolloutFunc`, *optional*): Function to use for generating completions. 
It receives the list of prompts allocated to the current process and the trainer instance. It must return a dict with `"prompt_ids"`, `"completion_ids"`, and @@ -625,6 +627,7 @@ def __init__( # Reference model self.beta = args.beta + if self.beta == 0.0: # If beta is 0.0, the reference model is not needed self.ref_model = None @@ -850,7 +853,7 @@ def _set_signature_columns_if_needed(self): # and "attention_mask"). In GRPOTrainer, we preprocess data, so using the model's signature columns doesn't # work. Instead, we set them to the columns expected by the `training_step` method, hence the override. if self._signature_columns is None: - self._signature_columns = ["prompt", "image", "images"] + self._signature_columns = ["prompt", "image", "images", "tools"] # This method overrides `Trainer.get_train_dataloader` to support our custom batching strategy. # Instead of returning a standard per-step batch (i.e., `per_device_batch_size), our dataloader loads an @@ -1239,8 +1242,15 @@ async def _run_async_funcs(): rewards_per_func = gather(rewards_per_func) return rewards_per_func - def _tokenize_prompts(self, prompts: list): - """Tokenize prompts and extract images/multimodal fields for generation.""" + def _tokenize_prompts(self, prompts: list, per_sample_tools: "list[list[Callable]] | None" = None): + """Tokenize prompts and extract images/multimodal fields for generation. + + Args: + prompts: List of prompts (conversational or plain text). + per_sample_tools: Optional list of per-sample tool lists. When provided, each prompt is tokenized + individually with its own tool list (needed because different samples may expose different tools + in the chat template). When ``None``, all prompts are batch-tokenized with ``self.tools``. 
+ """ if is_conversational({"prompt": prompts[0]}): # Extract images from messages for VLM support images = [] @@ -1256,26 +1266,58 @@ def _tokenize_prompts(self, prompts: list): images.append(prompt_images if prompt_images else None) images = images if has_images else None - # We pass padding=True to work around a bug introduced in transformers 5.2.0 in some processors - # (e.g. Qwen2.5-VL) that crash on batched unpadded input. We then unpad input_ids using attention_mask. - # See: https://github.com/huggingface/transformers/issues/44514 - tokenized = self.processing_class.apply_chat_template( - conversation=prompts, - tools=self.tools or None, # `or None`: Llama bug: it renders tool boilerplate for tools=[] - chat_template=self.chat_template, - add_generation_prompt=True, - tokenize=True, - return_dict=True, - padding=True, - **self.chat_template_kwargs, - ) - # Unpad input_ids: remove padding tokens using attention_mask to get per-sequence lists - prompt_ids = [ - [tok for tok, m in zip(ids, mask, strict=True) if m] - for ids, mask in zip(tokenized["input_ids"], tokenized["attention_mask"], strict=True) - ] - # For VLMs, the processor returns extra multimodal fields (pixel_values, image_grid_thw, etc.) - multimodal_fields = {k: v for k, v in tokenized.items() if k not in ("input_ids", "attention_mask")} + # When per_sample_tools is provided, we must tokenize each prompt individually because + # apply_chat_template doesn't support varying tools across a batch. 
+ if per_sample_tools is not None: + prompt_ids = [] + all_multimodal_fields = {} + for i, (prompt, tools_for_sample) in enumerate(zip(prompts, per_sample_tools, strict=True)): + tokenized = self.processing_class.apply_chat_template( + conversation=[prompt], + tools=tools_for_sample + or None, # `or None`: Llama bug: it renders tool boilerplate for tools=[] + chat_template=self.chat_template, + add_generation_prompt=True, + tokenize=True, + return_dict=True, + padding=False, + **self.chat_template_kwargs, + ) + prompt_ids.append(tokenized["input_ids"][0]) + for k, v in tokenized.items(): + if k not in ("input_ids", "attention_mask"): + # apply_chat_template with conversation=[prompt] returns tensors with a leading + # batch-of-1 dimension. Index with [0] for both lists and tensors so that + # torch.stack below produces shape (N, ...) not (N, 1, ...). + all_multimodal_fields.setdefault(k, []).append( + v[0] if isinstance(v, (list, torch.Tensor)) else v + ) + # Normalize tensor fields so they are batched consistently with the non-per-sample path. + for k, values in all_multimodal_fields.items(): + if values and isinstance(values[0], torch.Tensor): + all_multimodal_fields[k] = torch.stack(values, dim=0) + multimodal_fields = all_multimodal_fields + else: + # We pass padding=True to work around a bug introduced in transformers 5.2.0 in some processors + # (e.g. Qwen2.5-VL) that crash on batched unpadded input. We then unpad input_ids using attention_mask. 
+ # See: https://github.com/huggingface/transformers/issues/44514 + tokenized = self.processing_class.apply_chat_template( + conversation=prompts, + tools=self.tools or None, # `or None`: Llama bug: it renders tool boilerplate for tools=[] + chat_template=self.chat_template, + add_generation_prompt=True, + tokenize=True, + return_dict=True, + padding=True, + **self.chat_template_kwargs, + ) + # Unpad input_ids: remove padding tokens using attention_mask to get per-sequence lists + prompt_ids = [ + [tok for tok, m in zip(ids, mask, strict=True) if m] + for ids, mask in zip(tokenized["input_ids"], tokenized["attention_mask"], strict=True) + ] + # For VLMs, the processor returns extra multimodal fields (pixel_values, image_grid_thw, etc.) + multimodal_fields = {k: v for k, v in tokenized.items() if k not in ("input_ids", "attention_mask")} else: prompt_ids = self.processing_class(text=prompts)["input_ids"] images = None @@ -1407,8 +1449,37 @@ def _get_tool_suffix_ids(self, tool_messages): return full_ids[len(prefix_ids) :] - def _tool_call_loop(self, prompts, prompt_ids, completion_ids, completions, logprobs, images, multimodal_fields): + def _tool_call_loop( + self, + prompts, + prompt_ids, + completion_ids, + completions, + logprobs, + images, + multimodal_fields, + per_sample_tools=None, + ): # Tool execution loop: execute tools, then regenerate completions with tool results appended to the prompt + + # Build per-sample sync/async tool dicts from the provided callables. 
+ if per_sample_tools is not None: + sync_tool_dicts = [] + async_tool_dicts = [] + for sample_tools in per_sample_tools: + sync_dict = {} + async_dict = {} + for tool in sample_tools: + if inspect.iscoroutinefunction(tool): + async_dict[tool.__name__] = tool + else: + sync_dict[tool.__name__] = tool + sync_tool_dicts.append(sync_dict) + async_tool_dicts.append(async_dict) + else: + sync_tool_dicts = self._sync_tool_dicts + async_tool_dicts = self._async_tool_dicts + tool_calls = [completion[0].get("tool_calls") for completion in completions] idxs_with_tool = [idx for idx, tool_call in enumerate(tool_calls) if tool_call] tool_calls = [tool_calls[idx] for idx in idxs_with_tool] @@ -1424,8 +1495,8 @@ def _tool_call_loop(self, prompts, prompt_ids, completion_ids, completions, logp idx_with_tool = idxs_with_tool[idx] tool_call_list = tool_calls[idx] prompt_completion_tool = prompt_completion_tools[idx] - sync_tool_dict = self._sync_tool_dicts[idx_with_tool] - async_tool_dict = self._async_tool_dicts[idx_with_tool] + sync_tool_dict = sync_tool_dicts[idx_with_tool] + async_tool_dict = async_tool_dicts[idx_with_tool] # Append the last assistant message (which triggered tool_calls) to the prompt prompt_completion_tool.append(completions[idx_with_tool][-1]) async_coros = [] @@ -1586,7 +1657,7 @@ async def _run_async_tools(async_coros): return tool_mask, completions, completion_ids, logprobs, tool_call_count, tool_failure_count - def _generate(self, prompts: list): + def _generate(self, prompts: list, per_sample_tools: "list[list[Callable]] | None" = None): device = self.accelerator.device mode = "train" if self.model.training else "eval" @@ -1612,7 +1683,7 @@ def _generate(self, prompts: list): extra_fields = {k: v for k, v in output.items() if k not in required_keys} prompt_ids, completion_ids, logprobs = output["prompt_ids"], output["completion_ids"], output["logprobs"] else: - prompt_ids, images, multimodal_fields = self._tokenize_prompts(prompts) + prompt_ids, images, 
multimodal_fields = self._tokenize_prompts(prompts, per_sample_tools=per_sample_tools) completion_ids, logprobs = self._generate_single_turn(prompt_ids, images, multimodal_fields) extra_fields = {} @@ -1641,7 +1712,14 @@ def _generate(self, prompts: list): tool_call_count, tool_failure_count, ) = self._tool_call_loop( - prompts, prompt_ids, completion_ids, completions, logprobs, images, multimodal_fields + prompts, + prompt_ids, + completion_ids, + completions, + logprobs, + images, + multimodal_fields, + per_sample_tools=per_sample_tools, ) else: # Support custom env_mask from rollout_func (e.g., for environment feedback masking) @@ -1742,6 +1820,23 @@ def _generate_and_score_completions( for prompt, image_list in zip(prompts, images, strict=True) ] + # If the dataset has a "tools" column, use it for per-sample tool filtering. + # Each sample lists the tool names it wants; missing/None falls back to all tools. + # Note: combining a `tools` column with `environment_factory` is not supported. When + # environments are used, `self.tools` holds environment 0's bound methods (representative + # for schema rendering), so all per-sample callables would resolve to environment 0's + # instance instead of each sample's own environment. + per_sample_tools = None + if self.tools and inputs and "tools" in inputs[0]: + tool_lookup = {tool.__name__: tool for tool in self.tools} + per_sample_tools = [] + for example in inputs: + tool_names = example.get("tools") + if tool_names is not None: + per_sample_tools.append([tool_lookup[name] for name in tool_names]) + else: + per_sample_tools.append(list(tool_lookup.values())) + ( prompt_ids_list, completion_ids_list, @@ -1750,7 +1845,7 @@ def _generate_and_score_completions( num_items_in_batch, sampling_per_token_logps_list, extra_fields, - ) = self._generate(prompts) + ) = self._generate(prompts, per_sample_tools=per_sample_tools) # Convert lists of token IDs to padded tensors prompt_ids = [torch.tensor(ids) for ids in prompt_ids_list]