
Add per-sample tool filtering to GRPOTrainer via tools column #5398

Open
lailanelkoussy wants to merge 21 commits into huggingface:main from
lailanelkoussy:feat/per-sample-tool-filtering

Conversation

Contributor

@lailanelkoussy lailanelkoussy commented Mar 27, 2026

Add per-sample tool filtering to GRPOTrainer

What does this PR do?

Adds per-sample tool filtering to GRPOTrainer. When the dataset contains a "tools" column,
each sample can specify which subset of the global tools pool is available, controlling both the
prompt schema (which tools the model sees) and execution (which tool calls are allowed to run).

Motivation

In agentic GRPO training, different samples often require different tool subsets. For example, a math
problem should only expose a calculator tool, while a translation task should only expose a
translator. Without this feature, every sample sees the full tool list, leading to:

  • Noisy learning signal: the model wastes exploration budget on irrelevant tools
  • Incorrect reward attribution: tool call failures due to schema mismatch confuse the reward signal
  • No curriculum control: no way to progressively introduce tools during training

API

from datasets import Dataset
from trl import GRPOTrainer

dataset = Dataset.from_dict({
    "prompt": [
        [{"role": "user", "content": "What is 2+2?"}],
        [{"role": "user", "content": "Translate 'hello'"}],
    ],
    "tools": [["calculator"], ["translator"]],  # ← per-sample tool column, auto-detected
})

trainer = GRPOTrainer(
    model=model,
    reward_funcs=reward,
    tools=[calculator, translator],
    train_dataset=dataset,
)

When the dataset has no tools column, behavior is identical to the existing API — full backward compatibility.
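For context, the `calculator` and `translator` callables referenced in the snippet above are not defined in this excerpt; a minimal sketch of what they could look like is below (these exact implementations are illustrative, not from the PR — though the PR's docs do use a safe AST-based calculator in place of `eval()`). The functions' `__name__` attributes are what the dataset's "tools" column refers to.

```python
import ast
import operator

def calculator(expression: str) -> str:
    """Safely evaluate a basic arithmetic expression via the AST, not eval()."""
    ops = {ast.Add: operator.add, ast.Sub: operator.sub,
           ast.Mult: operator.mul, ast.Div: operator.truediv}

    def _eval(node: ast.AST):
        if isinstance(node, ast.Constant):
            return node.value
        if isinstance(node, ast.BinOp) and type(node.op) in ops:
            return ops[type(node.op)](_eval(node.left), _eval(node.right))
        raise ValueError("unsupported expression")

    return str(_eval(ast.parse(expression, mode="eval").body))

def translator(text: str, target_lang: str = "fr") -> str:
    """Toy stand-in for a translation tool; real usage would call a model or API."""
    return f"[{target_lang}] {text}"
```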

Implementation details

Area | Change
_set_signature_columns_if_needed | Always includes "tools" so the column passes through the dataloader
_tokenize_prompts | New per_sample_tools param; tokenizes each prompt individually with its own tool schema when set
_tool_call_loop | New per_sample_tools param; builds per-sample sync/async tool dicts for execution filtering
_generate | Threads per_sample_tools to tokenization and tool-call loop
_generate_and_score_completions | Auto-detects "tools" key in batch; resolves tool names via per-sample _sync_tool_dicts/_async_tool_dicts; None entries fall back to the full tool list
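The name-to-callable resolution with None fallback described above can be sketched roughly as follows (the helper name `resolve_per_sample_tools` is hypothetical; the actual logic lives inside `grpo_trainer.py`):

```python
from typing import Callable, Optional

def resolve_per_sample_tools(
    batch_tools: list[Optional[list[str]]],  # the "tools" column for this batch
    all_tools: list[Callable],               # the global tools pool
) -> list[list[Callable]]:
    """Map per-sample tool-name lists to callables; None falls back to all tools."""
    by_name = {fn.__name__: fn for fn in all_tools}
    resolved = []
    for names in batch_tools:
        if names is None:
            resolved.append(list(all_tools))  # no restriction for this sample
        else:
            resolved.append([by_name[n] for n in names])
    return resolved
```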

Tests

New test file tests/test_grpo_tools_column.py:

Validation tests:

  • Valid configuration (dataset with tools column) passes without error
  • Backward compatibility: no tools column → trainer uses all tools
  • Per-sample tool dict correctness
  • Signature columns always include "tools"

Integration tests:

  • End-to-end training with a tools dataset column (all samples, same tools)
  • Training with tool subset restriction (some calls fail as expected)
  • Backward compatibility (tools provided, no dataset tools column)
  • None in tools column falls back to full tool list
  • Async tools with per-sample filtering
  • Environment factory with per-sample filtering (correct per-environment method binding)

Files changed

  • trl/trainer/grpo_trainer.py — core implementation
  • tests/test_grpo_tools_column.py — new file (test suite)
  • docs/source/per_sample_tools.md — new file (documentation)

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case)
  • Did you make sure to update the documentation with your changes?
  • Did you write any new necessary tests?

Note

Medium Risk
Touches GRPO rollout/tokenization and tool-execution paths; incorrect filtering could change prompt schemas or tool dispatch during training despite backward-compat fallbacks and added tests.

Overview
Adds per-sample tool filtering to GRPOTrainer: when the training dataset includes a tools column (list of tool __name__ strings per row), the trainer now restricts both the rendered chat tool schema and runtime tool execution to that subset, with None/missing values falling back to the full global tools pool.

This threads a new per_sample_tools flow through prompt tokenization, generation, and the tool-call loop, and ensures tools is kept as a signature column so it isn’t dropped by column-pruning. Documentation is added (docs/source/per_sample_tools.md + toctree entry) and a comprehensive test suite is introduced covering validation, backward compatibility, subset behavior, None fallback, async tools, and interaction notes with environment_factory.

Written by Cursor Bugbot for commit 049b300.

Copilot AI review requested due to automatic review settings March 27, 2026 22:48
Contributor

Copilot AI left a comment


Pull request overview

Adds an optional tools_column_name to GRPOTrainer to enable per-sample tool exposure/execution filtering, allowing each dataset row to restrict the available tools (schema + execution) to a subset of the global tools pool.

Changes:

  • Extend GRPOTrainer with tools_column_name, a tool-name registry, dataset validation, and per-sample tool threading through tokenization and tool execution.
  • Add a dedicated test suite covering validation and end-to-end training behavior with per-sample tool filtering.
  • Add documentation describing how to use per-sample tool filtering.

Reviewed changes

Copilot reviewed 3 out of 3 changed files in this pull request and generated 9 comments.

File | Description
trl/trainer/grpo_trainer.py | Implements per-sample tool resolution, dataset validation, per-sample tokenization behavior, and filtered tool execution.
tests/test_grpo_tools_column.py | Adds validation + integration tests for tools_column_name behavior and backward compatibility.
docs/source/per_sample_tools.md | Adds user-facing documentation and examples for per-sample tool filtering.


- Validate tool name uniqueness before building _tool_registry; raise
  ValueError on duplicates to prevent silent overwrites
- Move tools_column_name + empty-tools check before super().__init__()
  to fail fast before model initialization
- Replace per-row dataset iteration with vectorized column access for
  efficient init-time validation
- Apply  guard to tools_for_sample in per-sample tokenization
  path to match existing Llama tools=[] workaround
- Stack multimodal tensor fields after per-sample accumulation so
  downstream generation sees correctly batched shapes
- Fix _tool_call_loop to filter existing _sync/_async_tool_dicts by
  allowed names instead of rebuilding from raw callables, preserving
  correct environment instance binding
- Resolve per-sample callables from per-sample env-bound dicts rather
  than the global _tool_registry (which only holds env[0] methods)
- Replace eval() in docs example with a safe AST-based calculator
- Add per_sample_tools page to _toctree.yml How-to guides section
- Gate duplicate tool name check on  to avoid
  breaking existing setups that don't use per-sample filtering
- Replace  with
  (consistent with rest of file; every tool callable has __name__)
- Fix multimodal tensor batch dimension in per-sample tokenization:
  index tensors with [0] alongside lists so torch.stack produces
  shape (N, ...) instead of (N, 1, ...) for VLM pixel_values etc.
- Remove self._tool_registry (registry pattern prohibited by AGENTS.md);
  replace dataset validation with a set comprehension over self.tools
- Remove overly defensive bounds checks (if idx < len(_sync_tool_dicts))
  in _tool_call_loop and per_sample_tools resolution; these cases cannot
  occur since the dicts are sized to generation_batch_size
- Remove dead fallback path that resolved tools from _tool_registry when
  _sync/_async_tool_dicts were empty (unreachable with self.tools guard)
- Fix per_sample_tools resolution to always derive callables from
  _sync_tool_dicts[i]/_async_tool_dicts[i], preserving correct
  environment binding per sample
- Update tests that asserted on trainer._tool_registry to check
  trainer._sync_tool_dicts[0] and trainer.tools instead
- Add test_training_with_tools_column_and_environment_factory to cover
  tools_column_name + environment_factory; verifies each sample uses its
  own environment's bound method rather than environments[0]'s
Use (, *optional*) instead of (, *optional*, defaults to )
for None-defaulted parameters, consistent with all other optional params
in the same docstring (tools, peft_config, environment_factory, etc.).
Member

@qgallouedec qgallouedec left a comment


Some early review: I think this PR can be massively simplified. No need to be defensive or to add new parameters — just supporting the extra tools column is enough.

@lailanelkoussy
Contributor Author

Thanks for the feedback! I've simplified the implementation: removed the tools_column_name parameter entirely and the validation blocks. The trainer now just auto-detects a tools column in the dataset. Happy to adjust further if needed.

_sync_tool_dicts and _async_tool_dicts are sized to generation_batch_size (train batch), so indexing them by position over eval inputs causes an IndexError when per_device_eval_batch_size exceeds that size.

In _generate_and_score_completions, build the tool lookup directly from self.tools instead of indexing _sync_tool_dicts. In _tool_call_loop, build per-sample sync/async dicts from the per_sample_tools callables directly using inspect.iscoroutinefunction, mirroring the init pattern.
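The "mirroring the init pattern" part — partitioning per-sample callables into sync and async dicts with `inspect.iscoroutinefunction` — can be sketched like this (the helper name `split_sync_async` is hypothetical; the trainer does this inline):

```python
import inspect
from typing import Callable

def split_sync_async(tools: list[Callable]) -> tuple[dict, dict]:
    """Partition tool callables into sync/async name->callable lookup dicts."""
    sync_tools: dict = {}
    async_tools: dict = {}
    for fn in tools:
        if inspect.iscoroutinefunction(fn):
            async_tools[fn.__name__] = fn  # awaited during the tool-call loop
        else:
            sync_tools[fn.__name__] = fn   # called directly
    return sync_tools, async_tools
```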
@lailanelkoussy lailanelkoussy changed the title from "Add per-sample tool filtering to GRPOTrainer via tools_column_name" to "Add per-sample tool filtering to GRPOTrainer via tools column" on Mar 30, 2026
  The test asserted per-environment binding (counter == 1 each) but
  per_sample_tools resolves from self.tools which only holds environment
  0's callables. Assertions now reflect the actual behavior (all calls
  dispatch to environment 0) consistent with the documented limitation.
@lailanelkoussy
Contributor Author

Thanks for the review @qgallouedec! Here's a summary of the main changes since I addressed your feedback:

Addressing your comments:

  • Removed tools_column_name parameter entirely — the trainer now auto-detects a "tools" column in the dataset, no new init arg needed.
  • Removed duplicate tool name validation — agreed it was unnecessary defensiveness.
  • Removed all dataset-level validation (the init-time iteration over dataset rows checking for unknown tool names) — too defensive per your feedback.
  • "tools" is now unconditionally included in _signature_columns so the column passes through the dataloader when present.

tools column + environment_factory — intentional non-support:

When environment_factory is set, the trainer creates N environment instances (one per generation_batch_size slot) and binds each environment's methods into per-slot tool dicts (_sync_tool_dicts[i] / _async_tool_dicts[i]) so that sample i executes on environment i. However, self.tools only stores environment 0's bound methods (line 491: tools + environment_methods[0]), because self.tools is used for schema/tokenization purposes and all environments share the same method signatures.

Per-sample tool filtering builds its lookup from self.tools, which means all resolved callables point to environment 0's instance. Making this work correctly would require cross-referencing the dataset tools column with the per-slot _sync_tool_dicts[i] / _async_tool_dicts[i], coupling the filtering logic tightly to the environment lifecycle. This adds real complexity for a use case that doesn't have a clear practical motivation yet — when using environments, the environment itself typically controls which tools are available.

Rather than adding that complexity now, I chose to document this as a known limitation (comment in _generate_and_score_completions) and write a test that verifies training still completes without errors, asserting the actual behavior (all calls dispatch to environment 0) rather than the ideal behavior.

Let me know if you agree with this approach or if you'd rather I do things differently. Thanks in advance

# Normalize tensor fields so they are batched consistently with the non-per-sample path.
for k, values in all_multimodal_fields.items():
    if values and isinstance(values[0], torch.Tensor):
        all_multimodal_fields[k] = torch.stack(values, dim=0)

Redundant per-sample tokenization for repeated generations

Low Severity

When per_sample_tools is set, every prompt in the generation batch is tokenized individually via apply_chat_template. Because the dataloader repeats each unique prompt num_generations times (via RepeatSampler), prompts with the same tool list are tokenized redundantly — e.g., with num_generations=8 and batch size 4, this results in 32 individual tokenizer calls instead of potentially far fewer. The batch path avoids this with a single batched apply_chat_template call. For per-sample tools, grouping prompts by identical tool sets and batch-tokenizing each group would avoid the redundant work.
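The grouping strategy suggested above could be sketched as follows (the helper name `group_by_tool_set` is hypothetical, and the real fix would then run one batched `apply_chat_template` call per group):

```python
from collections import defaultdict
from typing import Callable, Optional

def group_by_tool_set(
    per_sample_tools: list[Optional[list[Callable]]],
) -> dict[tuple, list[int]]:
    """Group sample indices by their (hashable) tool-name set so each group
    can be tokenized with a single batched apply_chat_template call."""
    groups: dict[tuple, list[int]] = defaultdict(list)
    for i, tools in enumerate(per_sample_tools):
        # Sorted tuple of names makes the tool set hashable and order-insensitive.
        key = tuple(sorted(fn.__name__ for fn in tools)) if tools else ()
        groups[key].append(i)
    return dict(groups)
```

With num_generations=8, the 8 repeats of each unique prompt land in the same group, so they can share one batched tokenizer call instead of 8 individual ones.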



@cursor cursor bot left a comment


Cursor Bugbot has reviewed your changes and found 1 potential issue.

There are 2 total unresolved issues (including 1 from previous review).


  - local: use_model
    title: Using Trained Models
- local: per_sample_tools
  title: Per-Sample Tool Filtering

Wrong toctree indentation breaks How-to guides section

Medium Severity

The new per_sample_tools entry uses 2-space indentation instead of 4-space, placing it at the top-level YAML list rather than inside the sections list of "How-to guides." This causes the title: How-to guides on line 44 to detach from its section and become a duplicate title key on the new standalone item. The "How-to guides" section loses its title, and the documentation navigation breaks.
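Under that diagnosis, the fix is to indent the new entry to match its sibling items inside the "How-to guides" sections list — a sketch (the surrounding use_model entry is taken from the quoted diff; the rest of the file is not shown here):

```yaml
  - local: use_model
    title: Using Trained Models
  - local: per_sample_tools
    title: Per-Sample Tool Filtering
```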

