diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml
index 624ba22b91b..054a6c286b0 100644
--- a/docs/source/_toctree.yml
+++ b/docs/source/_toctree.yml
@@ -39,6 +39,8 @@
title: Distributing Training
- local: use_model
title: Using Trained Models
+ - local: per_sample_tools
+ title: Per-Sample Tool Filtering
title: How-to guides
- sections:
- local: deepspeed_integration
diff --git a/docs/source/per_sample_tools.md b/docs/source/per_sample_tools.md
new file mode 100644
index 00000000000..00a6bb9c892
--- /dev/null
+++ b/docs/source/per_sample_tools.md
@@ -0,0 +1,108 @@
+# Per-Sample Tool Filtering in GRPOTrainer
+
+## Motivation
+
+In many agentic settings, different training samples require different subsets of tools.
+For example, a math question might only need a `calculator`, while a translation task
+might only need a `translator`. Exposing all tools on every sample can confuse the model
+and dilute the training signal.
+
+`GRPOTrainer` automatically detects a `tools` column in your dataset and uses it to restrict
+**which tools are available per sample**, drawn from the global `tools` pool.
+
+## How It Works
+
+1. **Global tool pool** — You pass the full set of tools to the trainer via `tools=[...]` as before.
+2. **Per-sample tool column** — Your dataset includes a `"tools"` column containing a list of tool
+ **names** (strings matching `tool.__name__`) allowed for each sample.
+3. **Automatic filtering** — For each rollout, only the specified tools appear in the model's
+ system prompt (chat template) and are available for execution. If the column value is `None`
+ for a sample, all tools are used as a fallback.
+
+## Example
+
+```python
+from datasets import Dataset
+from trl import GRPOTrainer
+
+
+# Define tool functions
+def calculator(number_a: float, operation: str, number_b: float) -> str:
+ """Perform a basic arithmetic operation on two numbers.
+
+ Args:
+ number_a: The first operand.
+ operation: The operation to perform. One of '+', '-', '*', '/'.
+ number_b: The second operand.
+
+ Returns:
+ The result of the operation as a string.
+
+ Raises:
+        ValueError: If the operation is not supported or division by zero is attempted. TypeError: If an operand cannot be converted to a number.
+ """
+ try:
+ number_a = float(number_a)
+ except (TypeError, ValueError):
+ raise TypeError(f"number_a must be convertible to a number, got {type(number_a).__name__!r}")
+ try:
+ number_b = float(number_b)
+ except (TypeError, ValueError):
+ raise TypeError(f"number_b must be convertible to a number, got {type(number_b).__name__!r}")
+ if operation == "+":
+ return str(number_a + number_b)
+ elif operation == "-":
+ return str(number_a - number_b)
+ elif operation == "*":
+ return str(number_a * number_b)
+ elif operation == "/":
+ if number_b == 0:
+ raise ValueError("Division by zero is not allowed.")
+ return str(number_a / number_b)
+ else:
+ raise ValueError(f"Unsupported operation '{operation}'. Use one of: +, -, *, /")
+
+
+def translator(text: str, target_language: str) -> str:
+ """Translate text to a target language.
+
+ Args:
+ text: The text to translate.
+ target_language: ISO language code, e.g. 'fr', 'es', 'de'.
+
+ Returns:
+ The translated text.
+ """
+ # Placeholder — in practice, call a translation API
+ return f"[{target_language}] {text}"
+
+
+# Build dataset with per-sample tool column
+dataset = Dataset.from_dict({
+ "prompt": [
+ [{"role": "user", "content": "What is 123 * 456?"}],
+ [{"role": "user", "content": "Translate 'good morning' to French."}],
+ [{"role": "user", "content": "Compute 2^10 and translate the result to Spanish."}],
+ ],
+ "tools": [
+ ["calculator"], # only calculator available
+ ["translator"], # only translator available
+ ["calculator", "translator"], # both available
+ ],
+})
+
+# The trainer automatically detects the "tools" column and applies per-sample filtering
+trainer = GRPOTrainer(
+ model="Qwen/Qwen2.5-0.5B-Instruct",
+    reward_funcs=my_reward,  # your reward function (or reward model id), defined elsewhere
+ tools=[calculator, translator],
+ train_dataset=dataset,
+)
+
+trainer.train()
+```
+
+## Backward Compatibility
+
+When the dataset has no `tools` column, behavior is identical to the existing API — all
+tools in the `tools` list are used for every sample.
diff --git a/tests/test_grpo_tools_column.py b/tests/test_grpo_tools_column.py
new file mode 100644
index 00000000000..651cbe84cb6
--- /dev/null
+++ b/tests/test_grpo_tools_column.py
@@ -0,0 +1,660 @@
+# Copyright 2020-2026 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Tests for the per-sample tool filtering feature in GRPOTrainer.
+
+When a dataset contains a `tools` column, GRPOTrainer automatically uses it to restrict
+which tools each sample can call during rollout.
+"""
+
+import os
+from unittest.mock import patch
+
+import pytest
+import torch
+import transformers
+from datasets import Dataset, load_dataset
+from packaging.version import Version
+
+from trl import GRPOConfig, GRPOTrainer
+
+from .testing_utils import TrlTestCase, require_jmespath
+
+
+# ──────────────────────────────────────────────────────────────────────
+# Tool definitions
+# ──────────────────────────────────────────────────────────────────────
+
+
+def multiply_tool(a: int, b: int) -> int:
+ """
+ Multiplies two integers.
+
+ Args:
+ a: The first integer.
+ b: The second integer.
+
+ Returns:
+ The product of the two integers.
+ """
+ return a * b
+
+
+def add_tool(a: int, b: int) -> int:
+ """
+ Adds two integers.
+
+ Args:
+ a: The first integer.
+ b: The second integer.
+
+ Returns:
+ The sum of the two integers.
+ """
+ return a + b
+
+
+async def async_add_tool(a: int, b: int) -> int:
+ """
+ Asynchronously adds two integers.
+
+ Args:
+ a: The first integer.
+ b: The second integer.
+
+ Returns:
+ The sum of the two integers.
+ """
+ return a + b
+
+
+# ──────────────────────────────────────────────────────────────────────
+# Unit-level tests (no model loading, fast)
+# ──────────────────────────────────────────────────────────────────────
+
+
+class TestToolsColumnValidation(TrlTestCase):
+ """Test that per-sample tool filtering via the `tools` dataset column works correctly."""
+
+ def _make_conversational_dataset(self, tool_names_per_sample):
+ """Helper: create a minimal conversational dataset with a tools column."""
+ prompts = [[{"role": "user", "content": f"Question {i}"}] for i in range(len(tool_names_per_sample))]
+ return Dataset.from_dict({"prompt": prompts, "tools": tool_names_per_sample})
+
+ @pytest.mark.xfail(
+ condition=Version(transformers.__version__) < Version("5.0.0"),
+ reason="Tool parsing is not supported in transformers versions below 5.0.0",
+ strict=True,
+ )
+ @require_jmespath
+ def test_valid_tools_column_passes_validation(self):
+ """A dataset with a tools column and matching tool pool should init without errors."""
+ dataset = self._make_conversational_dataset([["multiply_tool"], ["add_tool"], ["multiply_tool", "add_tool"]])
+ trainer = GRPOTrainer(
+ model="trl-internal-testing/tiny-Qwen3MoeForCausalLM",
+ reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
+ args=GRPOConfig(
+ output_dir=self.tmp_dir,
+ report_to="none",
+ per_device_train_batch_size=3,
+ num_generations=3,
+ ),
+ train_dataset=dataset,
+ tools=[multiply_tool, add_tool],
+ )
+ assert {t.__name__ for t in trainer.tools} == {"multiply_tool", "add_tool"}
+
+ @pytest.mark.xfail(
+ condition=Version(transformers.__version__) < Version("5.0.0"),
+ reason="Tool parsing is not supported in transformers versions below 5.0.0",
+ strict=True,
+ )
+ @require_jmespath
+ def test_no_tools_column_backward_compat(self):
+ """When the dataset has no tools column, trainer behaves as before (all tools used)."""
+ dataset = load_dataset("trl-internal-testing/zen", "conversational_prompt_only", split="train")
+ trainer = GRPOTrainer(
+ model="trl-internal-testing/tiny-Qwen3MoeForCausalLM",
+ reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
+ args=GRPOConfig(
+ output_dir=self.tmp_dir,
+ report_to="none",
+ per_device_train_batch_size=3,
+ num_generations=3,
+ ),
+ train_dataset=dataset,
+ tools=[multiply_tool],
+ )
+ assert {t.__name__ for t in trainer.tools} == {"multiply_tool"}
+
+ @pytest.mark.xfail(
+ condition=Version(transformers.__version__) < Version("5.0.0"),
+ reason="Tool parsing is not supported in transformers versions below 5.0.0",
+ strict=True,
+ )
+ @require_jmespath
+ def test_tools_accessible_per_sample(self):
+ """Each per-sample tool dict should contain the correct callables for its sample."""
+ dataset = self._make_conversational_dataset([["multiply_tool", "add_tool"]])
+ trainer = GRPOTrainer(
+ model="trl-internal-testing/tiny-Qwen3MoeForCausalLM",
+ reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
+ args=GRPOConfig(
+ output_dir=self.tmp_dir,
+ report_to="none",
+ per_device_train_batch_size=3,
+ num_generations=3,
+ ),
+ train_dataset=dataset,
+ tools=[multiply_tool, add_tool],
+ )
+ assert trainer._sync_tool_dicts[0]["multiply_tool"] is multiply_tool
+ assert trainer._sync_tool_dicts[0]["add_tool"] is add_tool
+
+ @pytest.mark.xfail(
+ condition=Version(transformers.__version__) < Version("5.0.0"),
+ reason="Tool parsing is not supported in transformers versions below 5.0.0",
+ strict=True,
+ )
+ @require_jmespath
+ def test_signature_columns_include_tools(self):
+ """The `tools` column should always be included in _signature_columns."""
+ dataset = self._make_conversational_dataset([["multiply_tool"]])
+ trainer = GRPOTrainer(
+ model="trl-internal-testing/tiny-Qwen3MoeForCausalLM",
+ reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
+ args=GRPOConfig(
+ output_dir=self.tmp_dir,
+ report_to="none",
+ per_device_train_batch_size=3,
+ num_generations=3,
+ ),
+ train_dataset=dataset,
+ tools=[multiply_tool],
+ )
+ trainer._set_signature_columns_if_needed()
+ assert "tools" in trainer._signature_columns
+
+
+# ──────────────────────────────────────────────────────────────────────
+# Integration tests (model loading + training with fake_generate)
+# ──────────────────────────────────────────────────────────────────────
+
+
+class TestToolsColumnTraining(TrlTestCase):
+ """End-to-end tests for training with per-sample tool filtering."""
+
+ @pytest.mark.xfail(
+ condition=Version(transformers.__version__) < Version("5.0.0"),
+ reason="Tool parsing is not supported in transformers versions below 5.0.0",
+ strict=True,
+ )
+ @require_jmespath
+ def test_training_with_tools_column(self):
+ """Train with a `tools` dataset column and verify per-sample filtering works end-to-end.
+
+ We create a 3-sample dataset where:
+ - Sample 0: only multiply_tool available → model calls multiply_tool (succeeds)
+ - Sample 1: only multiply_tool available → model calls multiply_tool (fails, wrong arg)
+ - Sample 2: only multiply_tool available → model returns plain text (no tool call)
+ """
+ dataset = Dataset.from_dict(
+ {
+ "prompt": [
+ [{"role": "user", "content": "What is 3 times 4?"}],
+ [{"role": "user", "content": "Multiply 5 and 6."}],
+ [{"role": "user", "content": "Tell me a joke."}],
+ ],
+ "tools": [
+ ["multiply_tool"],
+ ["multiply_tool"],
+ ["multiply_tool"],
+ ],
+ }
+ )
+
+ training_args = GRPOConfig(
+ output_dir=self.tmp_dir,
+ learning_rate=0.1,
+ per_device_train_batch_size=3,
+ num_generations=3,
+ max_completion_length=128,
+ report_to="none",
+ )
+ trainer = GRPOTrainer(
+ model="trl-internal-testing/tiny-Qwen3MoeForCausalLM",
+ reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
+ args=training_args,
+ train_dataset=dataset,
+ tools=[multiply_tool, add_tool], # global pool has both; dataset column restricts per sample
+ )
+
+ previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
+
+ def fake_generate(input_ids, **kwargs):
+ if input_ids.shape[0] == 3: # first call
+ # fmt: off
+ completion_ids = torch.tensor(
+ [
+ # '\n{"name": "multiply_tool", "arguments": {"a": 3, "b": 4}}\n<|im_end|>'
+ [151657, 198, 4913, 606, 788, 330, 64648, 22785, 497, 330, 16370, 788, 5212, 64, 788, 220, 18, 11, 330, 65, 788, 220, 19, 11248, 151658, 151645],
+ # invalid tool call: wrong argument name "c" instead of "b"
+ # '\n{"name": "multiply_tool", "arguments": {"a": 3, "c": 4}}\n<|im_end|>'
+ [151657, 198, 4913, 606, 788, 330, 64648, 22785, 497, 330, 16370, 788, 5212, 64, 788, 220, 18, 11, 330, 66, 788, 220, 19, 11248, 151658, 151645],
+ # "I don't know any tool<|im_end|>"
+ [40, 1513, 944, 1414, 894, 5392, 151645, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643],
+ ],
+ device=input_ids.device,
+ )
+ # fmt: on
+ else: # second call: only 2 samples had tool calls
+ completion_ids = torch.tensor(
+ [
+ # 'Done!<|im_end|>'
+ [17453, 0, 151645],
+ # 'Done!<|im_end|>'
+ [17453, 0, 151645],
+ ],
+ device=input_ids.device,
+ )
+ return torch.cat([input_ids, completion_ids], dim=-1)
+
+ with patch.object(trainer.model, "generate", side_effect=fake_generate):
+ trainer.train()
+
+ assert trainer.state.log_history[-1]["train_loss"] is not None
+ assert trainer.state.log_history[-1]["tools/call_frequency"] == pytest.approx(2 / 3)
+ assert trainer.state.log_history[-1]["tools/failure_frequency"] == pytest.approx(1 / 2)
+
+ # Check that the params have changed
+ for n, param in previous_trainable_params.items():
+ new_param = trainer.model.get_parameter(n)
+ assert not torch.equal(param, new_param), f"Parameter {n} has not changed."
+
+ @pytest.mark.xfail(
+ condition=Version(transformers.__version__) < Version("5.0.0"),
+ reason="Tool parsing is not supported in transformers versions below 5.0.0",
+ strict=True,
+ )
+ @require_jmespath
+ def test_training_with_tools_column_subset(self):
+ """Verify that the model only sees the subset of tools declared per sample.
+
+ We have two tools (multiply_tool, add_tool) but each sample only allows one.
+ Because the model's fake output calls multiply_tool, the sample that only allows add_tool
+ should fail the tool call (tool not found) — proving the filtering works.
+ """
+ dataset = Dataset.from_dict(
+ {
+ "prompt": [
+ [{"role": "user", "content": "What is 3 times 4?"}],
+ [{"role": "user", "content": "Add 1 and 2."}],
+ [{"role": "user", "content": "Tell me a joke."}],
+ ],
+ "tools": [
+ ["multiply_tool"], # sample 0: only multiply allowed
+ ["add_tool"], # sample 1: only add allowed → multiply_tool call will fail ("not found")
+ ["multiply_tool", "add_tool"], # sample 2: both allowed, but model won't call any
+ ],
+ }
+ )
+
+ training_args = GRPOConfig(
+ output_dir=self.tmp_dir,
+ learning_rate=0.1,
+ per_device_train_batch_size=3,
+ num_generations=3,
+ max_completion_length=128,
+ report_to="none",
+ )
+ trainer = GRPOTrainer(
+ model="trl-internal-testing/tiny-Qwen3MoeForCausalLM",
+ reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
+ args=training_args,
+ train_dataset=dataset,
+ tools=[multiply_tool, add_tool],
+ )
+
+ def fake_generate(input_ids, **kwargs):
+ if input_ids.shape[0] == 3: # first call
+ # fmt: off
+ completion_ids = torch.tensor(
+ [
+ # Sample 0: calls multiply_tool → should succeed (multiply_tool is allowed)
+ # '\n{"name": "multiply_tool", "arguments": {"a": 3, "b": 4}}\n<|im_end|>'
+ [151657, 198, 4913, 606, 788, 330, 64648, 22785, 497, 330, 16370, 788, 5212, 64, 788, 220, 18, 11, 330, 65, 788, 220, 19, 11248, 151658, 151645],
+ # Sample 1: calls multiply_tool → should FAIL (only add_tool is allowed)
+ # '\n{"name": "multiply_tool", "arguments": {"a": 3, "b": 4}}\n<|im_end|>'
+ [151657, 198, 4913, 606, 788, 330, 64648, 22785, 497, 330, 16370, 788, 5212, 64, 788, 220, 18, 11, 330, 65, 788, 220, 19, 11248, 151658, 151645],
+ # Sample 2: plain text, no tool call
+ # "I don't know any tool<|im_end|>"
+ [40, 1513, 944, 1414, 894, 5392, 151645, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643],
+ ],
+ device=input_ids.device,
+ )
+ # fmt: on
+ else: # second call: 2 samples had tool calls
+ completion_ids = torch.tensor(
+ [
+ [17453, 0, 151645], # 'Done!<|im_end|>'
+ [17453, 0, 151645], # 'Done!<|im_end|>'
+ ],
+ device=input_ids.device,
+ )
+ return torch.cat([input_ids, completion_ids], dim=-1)
+
+ with patch.object(trainer.model, "generate", side_effect=fake_generate):
+ trainer.train()
+
+ assert trainer.state.log_history[-1]["train_loss"] is not None
+ # 2 out of 3 samples per batch made tool calls
+ assert trainer.state.log_history[-1]["tools/call_frequency"] == pytest.approx(2 / 3)
+ # At least some calls should fail because sample 1 only allows add_tool but model calls multiply_tool.
+ # The exact failure rate depends on batch composition across epochs/steps.
+ assert trainer.state.log_history[-1]["tools/failure_frequency"] > 0
+
+ @pytest.mark.xfail(
+ condition=Version(transformers.__version__) < Version("5.0.0"),
+ reason="Tool parsing is not supported in transformers versions below 5.0.0",
+ strict=True,
+ )
+ @require_jmespath
+ def test_training_without_tools_column_backward_compat(self):
+ """Training with tools but without a `tools` dataset column should work exactly as before."""
+ dataset = load_dataset("trl-internal-testing/zen", "conversational_prompt_only", split="train")
+
+ training_args = GRPOConfig(
+ output_dir=self.tmp_dir,
+ learning_rate=0.1,
+ per_device_train_batch_size=3,
+ num_generations=3,
+ max_completion_length=128,
+ report_to="none",
+ )
+ trainer = GRPOTrainer(
+ model="trl-internal-testing/tiny-Qwen3MoeForCausalLM",
+ reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
+ args=training_args,
+ train_dataset=dataset,
+ tools=[multiply_tool],
+ )
+
+ previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
+
+ def fake_generate(input_ids, **kwargs):
+ if input_ids.shape[0] == 3:
+ # fmt: off
+ completion_ids = torch.tensor(
+ [
+ # '\n{"name": "multiply_tool", "arguments": {"a": 3, "b": 4}}\n<|im_end|>'
+ [151657, 198, 4913, 606, 788, 330, 64648, 22785, 497, 330, 16370, 788, 5212, 64, 788, 220, 18, 11, 330, 65, 788, 220, 19, 11248, 151658, 151645],
+ # invalid tool call: wrong argument name
+ [151657, 198, 4913, 606, 788, 330, 64648, 22785, 497, 330, 16370, 788, 5212, 64, 788, 220, 18, 11, 330, 66, 788, 220, 19, 11248, 151658, 151645],
+ # "I don't know any tool<|im_end|>"
+ [40, 1513, 944, 1414, 894, 5392, 151645, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643],
+ ],
+ device=input_ids.device,
+ )
+ # fmt: on
+ else:
+ completion_ids = torch.tensor(
+ [[17453, 0, 151645], [17453, 0, 151645]],
+ device=input_ids.device,
+ )
+ return torch.cat([input_ids, completion_ids], dim=-1)
+
+ with patch.object(trainer.model, "generate", side_effect=fake_generate):
+ trainer.train()
+
+ assert trainer.state.log_history[-1]["train_loss"] is not None
+ assert trainer.state.log_history[-1]["tools/call_frequency"] == pytest.approx(2 / 3)
+ assert trainer.state.log_history[-1]["tools/failure_frequency"] == pytest.approx(1 / 2)
+
+ for n, param in previous_trainable_params.items():
+ new_param = trainer.model.get_parameter(n)
+ assert not torch.equal(param, new_param), f"Parameter {n} has not changed."
+
+ @pytest.mark.xfail(
+ condition=Version(transformers.__version__) < Version("5.0.0"),
+ reason="Tool parsing is not supported in transformers versions below 5.0.0",
+ strict=True,
+ )
+ @require_jmespath
+ def test_training_with_none_in_tools_column_falls_back(self):
+ """When a sample's tools column is None, fall back to the full global tools list."""
+ dataset = Dataset.from_dict(
+ {
+ "prompt": [
+ [{"role": "user", "content": "What is 3 times 4?"}],
+ [{"role": "user", "content": "What is 5 plus 6?"}],
+ [{"role": "user", "content": "Tell me a joke."}],
+ ],
+ "tools": [
+ ["multiply_tool"], # only multiply
+ None, # fallback to all tools (multiply + add)
+ ["multiply_tool", "add_tool"], # both tools (fake_generate calls multiply_tool for all batches)
+ ],
+ }
+ )
+
+ training_args = GRPOConfig(
+ output_dir=self.tmp_dir,
+ learning_rate=0.1,
+ per_device_train_batch_size=3,
+ num_generations=3,
+ max_completion_length=128,
+ report_to="none",
+ )
+ trainer = GRPOTrainer(
+ model="trl-internal-testing/tiny-Qwen3MoeForCausalLM",
+ reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
+ args=training_args,
+ train_dataset=dataset,
+ tools=[multiply_tool, add_tool],
+ )
+
+ def fake_generate(input_ids, **kwargs):
+ if input_ids.shape[0] == 3:
+ # fmt: off
+ completion_ids = torch.tensor(
+ [
+ # Sample 0: calls multiply_tool → ok (allowed)
+ [151657, 198, 4913, 606, 788, 330, 64648, 22785, 497, 330, 16370, 788, 5212, 64, 788, 220, 18, 11, 330, 65, 788, 220, 19, 11248, 151658, 151645],
+ # Sample 1: calls multiply_tool → ok (None fallback = all tools)
+ [151657, 198, 4913, 606, 788, 330, 64648, 22785, 497, 330, 16370, 788, 5212, 64, 788, 220, 18, 11, 330, 65, 788, 220, 19, 11248, 151658, 151645],
+ # Sample 2: plain text
+ [40, 1513, 944, 1414, 894, 5392, 151645, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643],
+ ],
+ device=input_ids.device,
+ )
+ # fmt: on
+ else:
+ completion_ids = torch.tensor(
+ [[17453, 0, 151645], [17453, 0, 151645]],
+ device=input_ids.device,
+ )
+ return torch.cat([input_ids, completion_ids], dim=-1)
+
+ with patch.object(trainer.model, "generate", side_effect=fake_generate):
+ trainer.train()
+
+ # Both tool calls should succeed (sample 0 has multiply allowed, sample 1 has all tools via None fallback)
+ assert trainer.state.log_history[-1]["train_loss"] is not None
+ assert trainer.state.log_history[-1]["tools/call_frequency"] == pytest.approx(2 / 3)
+ # No failures: both samples' tool calls are to multiply_tool which is in their allowed set
+ assert trainer.state.log_history[-1]["tools/failure_frequency"] == pytest.approx(0.0)
+
+ @pytest.mark.xfail(
+ condition=Version(transformers.__version__) < Version("5.0.0"),
+ reason="Tool parsing is not supported in transformers versions below 5.0.0",
+ strict=True,
+ )
+ @require_jmespath
+ def test_training_with_async_tool_and_tools_column(self):
+ """Verify that async tools also work with per-sample filtering."""
+ dataset = Dataset.from_dict(
+ {
+ "prompt": [
+ [{"role": "user", "content": "What is 3 times 4?"}],
+ [{"role": "user", "content": "What is 5 plus 6?"}],
+ [{"role": "user", "content": "Tell me a joke."}],
+ ],
+ "tools": [
+ ["multiply_tool"],
+ ["multiply_tool"],
+ ["multiply_tool"],
+ ],
+ }
+ )
+
+ training_args = GRPOConfig(
+ output_dir=self.tmp_dir,
+ learning_rate=0.1,
+ per_device_train_batch_size=3,
+ num_generations=3,
+ max_completion_length=128,
+ report_to="none",
+ )
+ trainer = GRPOTrainer(
+ model="trl-internal-testing/tiny-Qwen3MoeForCausalLM",
+ reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
+ args=training_args,
+ train_dataset=dataset,
+ tools=[multiply_tool, async_add_tool], # mix of sync and async in global pool
+ )
+
+ def fake_generate(input_ids, **kwargs):
+ if input_ids.shape[0] == 3:
+ # fmt: off
+ completion_ids = torch.tensor(
+ [
+ [151657, 198, 4913, 606, 788, 330, 64648, 22785, 497, 330, 16370, 788, 5212, 64, 788, 220, 18, 11, 330, 65, 788, 220, 19, 11248, 151658, 151645],
+ [151657, 198, 4913, 606, 788, 330, 64648, 22785, 497, 330, 16370, 788, 5212, 64, 788, 220, 18, 11, 330, 66, 788, 220, 19, 11248, 151658, 151645],
+ [40, 1513, 944, 1414, 894, 5392, 151645, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643],
+ ],
+ device=input_ids.device,
+ )
+ # fmt: on
+ else:
+ completion_ids = torch.tensor(
+ [[17453, 0, 151645], [17453, 0, 151645]],
+ device=input_ids.device,
+ )
+ return torch.cat([input_ids, completion_ids], dim=-1)
+
+ with patch.object(trainer.model, "generate", side_effect=fake_generate):
+ trainer.train()
+
+ assert trainer.state.log_history[-1]["train_loss"] is not None
+
+ @pytest.mark.xfail(
+ condition=Version(transformers.__version__) < Version("5.2.0"),
+ reason="Environment factory support is not available in transformers versions below 5.2.0",
+ strict=True,
+ )
+ @require_jmespath
+ @patch.dict(os.environ, {"TRL_EXPERIMENTAL_SILENCE": "1"})
+ def test_training_with_tools_column_and_environment_factory(self):
+ """Verify training completes when tools column is combined with environment_factory.
+
+ Note: per-sample tool filtering resolves callables from self.tools, which only holds
+ environment 0's bound methods. All tool calls therefore dispatch to environment 0.
+ This is a known limitation documented in _generate_and_score_completions.
+ """
+ dataset = Dataset.from_dict(
+ {
+ "prompt": [
+ [{"role": "user", "content": "Increment by 1."}],
+ [{"role": "user", "content": "Increment by 1."}],
+ [{"role": "user", "content": "Tell me a joke."}],
+ ],
+ "tools": [
+ ["increment"],
+ ["increment"],
+ ["increment"],
+ ],
+ }
+ )
+
+ class DummyEnvironment:
+ def reset(self, **kwargs):
+ self._counter = 0
+
+ def increment(self, step: int) -> int:
+ """
+ Increment the internal counter.
+
+ Args:
+ step: Value to add to the counter.
+
+ Returns:
+ The updated counter value.
+ """
+ self._counter += step
+ return self._counter
+
+ training_args = GRPOConfig(
+ output_dir=self.tmp_dir,
+ learning_rate=0.1,
+ per_device_train_batch_size=3,
+ num_generations=3,
+ max_completion_length=128,
+ report_to="none",
+ )
+
+ trainer = GRPOTrainer(
+ model="trl-internal-testing/tiny-Qwen3MoeForCausalLM",
+ reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
+ args=training_args,
+ train_dataset=dataset,
+ environment_factory=DummyEnvironment,
+ )
+
+ def fake_generate(input_ids, **kwargs):
+ if input_ids.shape[0] == 3: # first call
+ # fmt: off
+ completion_ids = torch.tensor(
+ [
+ # Sample 0: '\n{"name": "increment", "arguments": {"step": 1}}\n<|im_end|>'
+ [151657, 198, 4913, 606, 788, 330, 35744, 497, 330, 16370, 788, 5212, 9520, 788, 220, 16, 11248, 151658, 151645, 151643],
+ # Sample 1: '\n{"name": "increment", "arguments": {"step": 1}}\n<|im_end|>'
+ [151657, 198, 4913, 606, 788, 330, 35744, 497, 330, 16370, 788, 5212, 9520, 788, 220, 16, 11248, 151658, 151645, 151643],
+ # Sample 2: "I won't increment<|im_end|>"
+ [40, 2765, 944, 16252, 151645, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643],
+ ],
+ device=input_ids.device,
+ )
+ # fmt: on
+ else: # second call: 2 samples had tool calls
+ completion_ids = torch.tensor(
+ [
+ [17453, 0, 151645], # 'Done!<|im_end|>'
+ [17453, 0, 151645], # 'Done!<|im_end|>'
+ ],
+ device=input_ids.device,
+ )
+ return torch.cat([input_ids, completion_ids], dim=-1)
+
+ with patch.object(trainer.model, "generate", side_effect=fake_generate):
+ trainer.train()
+
+ assert trainer.state.log_history[-1]["train_loss"] is not None
+ assert trainer.state.log_history[-1]["tools/call_frequency"] == pytest.approx(2 / 3)
+ assert trainer.state.log_history[-1]["tools/failure_frequency"] == pytest.approx(0.0)
+ # All tool calls resolve to environment 0's bound methods (known limitation):
+ # per_sample_tools is built from self.tools which only holds environment 0's callables.
+ assert trainer.environments[0]._counter == 2
+ assert trainer.environments[1]._counter == 0
+ assert trainer.environments[2]._counter == 0
diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py
index c5eed094192..30297e6c239 100644
--- a/trl/trainer/grpo_trainer.py
+++ b/trl/trainer/grpo_trainer.py
@@ -226,7 +226,9 @@ class GRPOTrainer(_BaseTrainer):
Google-style docstring describing its purpose, arguments, and return value. For more details, see:
https://huggingface.co/docs/transformers/en/chat_extras#passing-tools. The model uses the function's name,
type hints, and docstring to determine how to call it. Ensure that the model's chat template supports tool
- use and that it has been fine-tuned for tool calling.
+ use and that it has been fine-tuned for tool calling. If the dataset contains a `tools` column with a
+ list of tool names per sample, only those tools are exposed during that sample's rollout. If the column
+ value is missing or `None` for a sample, the full `tools` list is used.
rollout_func (`RolloutFunc`, *optional*):
Function to use for generating completions. It receives the list of prompts allocated to the current
process and the trainer instance. It must return a dict with `"prompt_ids"`, `"completion_ids"`, and
@@ -625,6 +627,7 @@ def __init__(
# Reference model
self.beta = args.beta
+
if self.beta == 0.0:
# If beta is 0.0, the reference model is not needed
self.ref_model = None
@@ -850,7 +853,7 @@ def _set_signature_columns_if_needed(self):
# and "attention_mask"). In GRPOTrainer, we preprocess data, so using the model's signature columns doesn't
# work. Instead, we set them to the columns expected by the `training_step` method, hence the override.
if self._signature_columns is None:
- self._signature_columns = ["prompt", "image", "images"]
+ self._signature_columns = ["prompt", "image", "images", "tools"]
# This method overrides `Trainer.get_train_dataloader` to support our custom batching strategy.
# Instead of returning a standard per-step batch (i.e., `per_device_batch_size), our dataloader loads an
@@ -1239,8 +1242,15 @@ async def _run_async_funcs():
rewards_per_func = gather(rewards_per_func)
return rewards_per_func
- def _tokenize_prompts(self, prompts: list):
- """Tokenize prompts and extract images/multimodal fields for generation."""
+ def _tokenize_prompts(self, prompts: list, per_sample_tools: "list[list[Callable]] | None" = None):
+ """Tokenize prompts and extract images/multimodal fields for generation.
+
+ Args:
+ prompts: List of prompts (conversational or plain text).
+ per_sample_tools: Optional list of per-sample tool lists. When provided, each prompt is tokenized
+ individually with its own tool list (needed because different samples may expose different tools
+ in the chat template). When ``None``, all prompts are batch-tokenized with ``self.tools``.
+ """
if is_conversational({"prompt": prompts[0]}):
# Extract images from messages for VLM support
images = []
@@ -1256,26 +1266,58 @@ def _tokenize_prompts(self, prompts: list):
images.append(prompt_images if prompt_images else None)
images = images if has_images else None
- # We pass padding=True to work around a bug introduced in transformers 5.2.0 in some processors
- # (e.g. Qwen2.5-VL) that crash on batched unpadded input. We then unpad input_ids using attention_mask.
- # See: https://github.com/huggingface/transformers/issues/44514
- tokenized = self.processing_class.apply_chat_template(
- conversation=prompts,
- tools=self.tools or None, # `or None`: Llama bug: it renders tool boilerplate for tools=[]
- chat_template=self.chat_template,
- add_generation_prompt=True,
- tokenize=True,
- return_dict=True,
- padding=True,
- **self.chat_template_kwargs,
- )
- # Unpad input_ids: remove padding tokens using attention_mask to get per-sequence lists
- prompt_ids = [
- [tok for tok, m in zip(ids, mask, strict=True) if m]
- for ids, mask in zip(tokenized["input_ids"], tokenized["attention_mask"], strict=True)
- ]
- # For VLMs, the processor returns extra multimodal fields (pixel_values, image_grid_thw, etc.)
- multimodal_fields = {k: v for k, v in tokenized.items() if k not in ("input_ids", "attention_mask")}
+ # When per_sample_tools is provided, we must tokenize each prompt individually because
+ # apply_chat_template doesn't support varying tools across a batch.
+ if per_sample_tools is not None:
+ prompt_ids = []
+ all_multimodal_fields = {}
+ for i, (prompt, tools_for_sample) in enumerate(zip(prompts, per_sample_tools, strict=True)):
+ tokenized = self.processing_class.apply_chat_template(
+ conversation=[prompt],
+ tools=tools_for_sample
+ or None, # `or None`: Llama bug: it renders tool boilerplate for tools=[]
+ chat_template=self.chat_template,
+ add_generation_prompt=True,
+ tokenize=True,
+ return_dict=True,
+ padding=False,
+ **self.chat_template_kwargs,
+ )
+ prompt_ids.append(tokenized["input_ids"][0])
+ for k, v in tokenized.items():
+ if k not in ("input_ids", "attention_mask"):
+ # apply_chat_template with conversation=[prompt] returns tensors with a leading
+ # batch-of-1 dimension. Index with [0] for both lists and tensors so that
+ # torch.stack below produces shape (N, ...) not (N, 1, ...).
+ all_multimodal_fields.setdefault(k, []).append(
+ v[0] if isinstance(v, (list, torch.Tensor)) else v
+ )
+ # Normalize tensor fields so they are batched consistently with the non-per-sample path.
+ for k, values in all_multimodal_fields.items():
+ if values and isinstance(values[0], torch.Tensor):
+ all_multimodal_fields[k] = torch.stack(values, dim=0)
+ multimodal_fields = all_multimodal_fields
+ else:
+ # We pass padding=True to work around a bug introduced in transformers 5.2.0 in some processors
+ # (e.g. Qwen2.5-VL) that crash on batched unpadded input. We then unpad input_ids using attention_mask.
+ # See: https://github.com/huggingface/transformers/issues/44514
+ tokenized = self.processing_class.apply_chat_template(
+ conversation=prompts,
+ tools=self.tools or None, # `or None`: Llama bug: it renders tool boilerplate for tools=[]
+ chat_template=self.chat_template,
+ add_generation_prompt=True,
+ tokenize=True,
+ return_dict=True,
+ padding=True,
+ **self.chat_template_kwargs,
+ )
+ # Unpad input_ids: remove padding tokens using attention_mask to get per-sequence lists
+ prompt_ids = [
+ [tok for tok, m in zip(ids, mask, strict=True) if m]
+ for ids, mask in zip(tokenized["input_ids"], tokenized["attention_mask"], strict=True)
+ ]
+ # For VLMs, the processor returns extra multimodal fields (pixel_values, image_grid_thw, etc.)
+ multimodal_fields = {k: v for k, v in tokenized.items() if k not in ("input_ids", "attention_mask")}
else:
prompt_ids = self.processing_class(text=prompts)["input_ids"]
images = None
@@ -1407,8 +1449,37 @@ def _get_tool_suffix_ids(self, tool_messages):
return full_ids[len(prefix_ids) :]
- def _tool_call_loop(self, prompts, prompt_ids, completion_ids, completions, logprobs, images, multimodal_fields):
+ def _tool_call_loop(
+ self,
+ prompts,
+ prompt_ids,
+ completion_ids,
+ completions,
+ logprobs,
+ images,
+ multimodal_fields,
+ per_sample_tools=None,
+ ):
# Tool execution loop: execute tools, then regenerate completions with tool results appended to the prompt
+
+ # Build per-sample sync/async tool dicts from the provided callables.
+ if per_sample_tools is not None:
+ sync_tool_dicts = []
+ async_tool_dicts = []
+ for sample_tools in per_sample_tools:
+ sync_dict = {}
+ async_dict = {}
+ for tool in sample_tools:
+ if inspect.iscoroutinefunction(tool):
+ async_dict[tool.__name__] = tool
+ else:
+ sync_dict[tool.__name__] = tool
+ sync_tool_dicts.append(sync_dict)
+ async_tool_dicts.append(async_dict)
+ else:
+ sync_tool_dicts = self._sync_tool_dicts
+ async_tool_dicts = self._async_tool_dicts
+
tool_calls = [completion[0].get("tool_calls") for completion in completions]
idxs_with_tool = [idx for idx, tool_call in enumerate(tool_calls) if tool_call]
tool_calls = [tool_calls[idx] for idx in idxs_with_tool]
@@ -1424,8 +1495,8 @@ def _tool_call_loop(self, prompts, prompt_ids, completion_ids, completions, logp
idx_with_tool = idxs_with_tool[idx]
tool_call_list = tool_calls[idx]
prompt_completion_tool = prompt_completion_tools[idx]
- sync_tool_dict = self._sync_tool_dicts[idx_with_tool]
- async_tool_dict = self._async_tool_dicts[idx_with_tool]
+ sync_tool_dict = sync_tool_dicts[idx_with_tool]
+ async_tool_dict = async_tool_dicts[idx_with_tool]
# Append the last assistant message (which triggered tool_calls) to the prompt
prompt_completion_tool.append(completions[idx_with_tool][-1])
async_coros = []
@@ -1586,7 +1657,7 @@ async def _run_async_tools(async_coros):
return tool_mask, completions, completion_ids, logprobs, tool_call_count, tool_failure_count
- def _generate(self, prompts: list):
+ def _generate(self, prompts: list, per_sample_tools: "list[list[Callable]] | None" = None):
device = self.accelerator.device
mode = "train" if self.model.training else "eval"
@@ -1612,7 +1683,7 @@ def _generate(self, prompts: list):
extra_fields = {k: v for k, v in output.items() if k not in required_keys}
prompt_ids, completion_ids, logprobs = output["prompt_ids"], output["completion_ids"], output["logprobs"]
else:
- prompt_ids, images, multimodal_fields = self._tokenize_prompts(prompts)
+ prompt_ids, images, multimodal_fields = self._tokenize_prompts(prompts, per_sample_tools=per_sample_tools)
completion_ids, logprobs = self._generate_single_turn(prompt_ids, images, multimodal_fields)
extra_fields = {}
@@ -1641,7 +1712,14 @@ def _generate(self, prompts: list):
tool_call_count,
tool_failure_count,
) = self._tool_call_loop(
- prompts, prompt_ids, completion_ids, completions, logprobs, images, multimodal_fields
+ prompts,
+ prompt_ids,
+ completion_ids,
+ completions,
+ logprobs,
+ images,
+ multimodal_fields,
+ per_sample_tools=per_sample_tools,
)
else:
# Support custom env_mask from rollout_func (e.g., for environment feedback masking)
@@ -1742,6 +1820,23 @@ def _generate_and_score_completions(
for prompt, image_list in zip(prompts, images, strict=True)
]
+ # If the dataset has a "tools" column, use it for per-sample tool filtering.
+        # Each sample lists allowed tool names; missing/None falls back to all tools, unknown names raise KeyError.
+ # Note: combining a `tools` column with `environment_factory` is not supported. When
+ # environments are used, `self.tools` holds environment 0's bound methods (representative
+ # for schema rendering), so all per-sample callables would resolve to environment 0's
+ # instance instead of each sample's own environment.
+ per_sample_tools = None
+ if self.tools and inputs and "tools" in inputs[0]:
+ tool_lookup = {tool.__name__: tool for tool in self.tools}
+ per_sample_tools = []
+ for example in inputs:
+ tool_names = example.get("tools")
+ if tool_names is not None:
+ per_sample_tools.append([tool_lookup[name] for name in tool_names])
+ else:
+ per_sample_tools.append(list(tool_lookup.values()))
+
(
prompt_ids_list,
completion_ids_list,
@@ -1750,7 +1845,7 @@ def _generate_and_score_completions(
num_items_in_batch,
sampling_per_token_logps_list,
extra_fields,
- ) = self._generate(prompts)
+ ) = self._generate(prompts, per_sample_tools=per_sample_tools)
# Convert lists of token IDs to padded tensors
prompt_ids = [torch.tensor(ids) for ids in prompt_ids_list]