Add per-sample tool filtering to `GRPOTrainer` via `tools` column #5398

lailanelkoussy wants to merge 21 commits into huggingface:main

Conversation
Pull request overview

Adds an optional `tools_column_name` to `GRPOTrainer` to enable per-sample tool exposure/execution filtering, allowing each dataset row to restrict the available tools (schema + execution) to a subset of the global tools pool.

Changes:
- Extend `GRPOTrainer` with `tools_column_name`, a tool-name registry, dataset validation, and per-sample tool threading through tokenization and tool execution.
- Add a dedicated test suite covering validation and end-to-end training behavior with per-sample tool filtering.
- Add documentation describing how to use per-sample tool filtering.
Reviewed changes

Copilot reviewed 3 out of 3 changed files in this pull request and generated 9 comments.

| File | Description |
|---|---|
| `trl/trainer/grpo_trainer.py` | Implements per-sample tool resolution, dataset validation, per-sample tokenization behavior, and filtered tool execution. |
| `tests/test_grpo_tools_column.py` | Adds validation + integration tests for `tools_column_name` behavior and backward compatibility. |
| `docs/source/per_sample_tools.md` | Adds user-facing documentation and examples for per-sample tool filtering. |
- Validate tool name uniqueness before building `_tool_registry`; raise `ValueError` on duplicates to prevent silent overwrites
- Move the `tools_column_name` + empty-tools check before `super().__init__()` to fail fast before model initialization
- Replace per-row dataset iteration with vectorized column access for efficient init-time validation
- Apply the guard to `tools_for_sample` in the per-sample tokenization path to match the existing Llama `tools=[]` workaround
- Stack multimodal tensor fields after per-sample accumulation so downstream generation sees correctly batched shapes
- Fix `_tool_call_loop` to filter the existing `_sync`/`_async_tool_dicts` by allowed names instead of rebuilding from raw callables, preserving correct environment instance binding
- Resolve per-sample callables from per-sample env-bound dicts rather than the global `_tool_registry` (which only holds `env[0]` methods)
- Replace `eval()` in the docs example with a safe AST-based calculator
- Add the `per_sample_tools` page to the `_toctree.yml` How-to guides section
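The "safe AST-based calculator" mentioned in the last commit could look roughly like the sketch below. This is an illustration of the technique, not the PR's actual docs example; the function name `calculator` and the supported operator set are assumptions.

```python
import ast
import operator

# Whitelist of arithmetic operators; anything outside this set is rejected.
_OPS = {
    ast.Add: operator.add,
    ast.Sub: operator.sub,
    ast.Mult: operator.mul,
    ast.Div: operator.truediv,
    ast.Pow: operator.pow,
    ast.USub: operator.neg,
}

def calculator(expression: str) -> float:
    """Safely evaluate an arithmetic expression without eval()."""
    def _eval(node):
        if isinstance(node, ast.Expression):
            return _eval(node.body)
        if isinstance(node, ast.Constant) and isinstance(node.value, (int, float)):
            return node.value
        if isinstance(node, ast.BinOp) and type(node.op) in _OPS:
            return _OPS[type(node.op)](_eval(node.left), _eval(node.right))
        if isinstance(node, ast.UnaryOp) and type(node.op) in _OPS:
            return _OPS[type(node.op)](_eval(node.operand))
        raise ValueError("Unsupported expression")
    return _eval(ast.parse(expression, mode="eval"))
```

Because only whitelisted node types are evaluated, attempts like `calculator("__import__('os')")` hit the `ValueError` branch instead of executing arbitrary code.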
- Gate the duplicate tool name check so it doesn't break existing setups that don't use per-sample filtering
- Replace the previous lookup with direct `__name__` access (consistent with the rest of the file; every tool callable has `__name__`)
- Fix the multimodal tensor batch dimension in per-sample tokenization: index tensors with `[0]` alongside lists so `torch.stack` produces shape `(N, ...)` instead of `(N, 1, ...)` for VLM `pixel_values` etc.
- Remove `self._tool_registry` (registry pattern prohibited by AGENTS.md); replace dataset validation with a set comprehension over `self.tools`
- Remove overly defensive bounds checks (`if idx < len(_sync_tool_dicts)`) in `_tool_call_loop` and `per_sample_tools` resolution; these cases cannot occur since the dicts are sized to `generation_batch_size`
- Remove the dead fallback path that resolved tools from `_tool_registry` when `_sync`/`_async_tool_dicts` were empty (unreachable with the `self.tools` guard)
- Fix `per_sample_tools` resolution to always derive callables from `_sync_tool_dicts[i]`/`_async_tool_dicts[i]`, preserving correct environment binding per sample
- Update tests that asserted on `trainer._tool_registry` to check `trainer._sync_tool_dicts[0]` and `trainer.tools` instead
- Add `test_training_with_tools_column_and_environment_factory` to cover `tools_column_name` + `environment_factory`; verifies each sample uses its own environment's bound method rather than `environments[0]`'s
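The "derive callables from `_sync_tool_dicts[i]`" fix above boils down to a per-sample dict filter. A minimal sketch of the idea, with the trainer context elided and the helper name `resolve_per_sample_tools` made up for illustration:

```python
def resolve_per_sample_tools(sync_tool_dicts, allowed_names):
    """For each sample i, keep only the env-bound callables whose names
    appear in that sample's allowed list; None means "keep the full pool"."""
    resolved = []
    for tool_dict, names in zip(sync_tool_dicts, allowed_names):
        if names is None:
            resolved.append(dict(tool_dict))  # full-pool fallback
        else:
            resolved.append({n: fn for n, fn in tool_dict.items() if n in names})
    return resolved
```

Filtering the already env-bound dicts (rather than rebuilding from raw callables) is what preserves the correct environment instance binding per sample.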
Use (, *optional*) instead of (, *optional*, defaults to ) for None-defaulted parameters, consistent with all other optional params in the same docstring (tools, peft_config, environment_factory, etc.).
qgallouedec
left a comment
Some early review: I think this PR can be massively simplified. No need to be defensive or to add new parameters; just supporting the extra `tools` column is enough.
…tools_column_name param
…ailanelkoussy/trl into feat/per-sample-tool-filtering
Thanks for the feedback! I've simplified the implementation: removed the
_sync_tool_dicts and _async_tool_dicts are sized to generation_batch_size (train batch), so indexing them by position over eval inputs causes an IndexError when per_device_eval_batch_size exceeds that size. In _generate_and_score_completions, build the tool lookup directly from self.tools instead of indexing _sync_tool_dicts. In _tool_call_loop, build per-sample sync/async dicts from the per_sample_tools callables directly using inspect.iscoroutinefunction, mirroring the init pattern.
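The "mirroring the init pattern" part refers to partitioning callables by `inspect.iscoroutinefunction`. A rough sketch of that pattern (the helper name `split_sync_async` is hypothetical, not TRL code):

```python
import inspect

def split_sync_async(tools):
    """Partition tool callables into sync and async dicts keyed by __name__."""
    sync_tools, async_tools = {}, {}
    for tool in tools:
        target = async_tools if inspect.iscoroutinefunction(tool) else sync_tools
        target[tool.__name__] = tool
    return sync_tools, async_tools
```

Building these dicts from the `per_sample_tools` callables directly, rather than indexing `_sync_tool_dicts` by position, sidesteps the eval-batch `IndexError` described above.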
`tools_column_name` → `tools` column
The test asserted per-environment binding (`counter == 1` each), but `per_sample_tools` resolves from `self.tools`, which only holds environment 0's callables. Assertions now reflect the actual behavior (all calls dispatch to environment 0), consistent with the documented limitation.
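The limitation can be illustrated with a toy sketch (the class and names below are illustrative, not TRL code): two environment instances exist, but the global pool holds only environment 0's bound method, so name-based lookup always dispatches there.

```python
class ToyEnv:
    """Stand-in for a per-sample environment with a bound tool method."""
    def __init__(self, tag: str):
        self.tag = tag
        self.calls = 0

    def tool(self) -> str:
        self.calls += 1
        return self.tag

envs = [ToyEnv("env0"), ToyEnv("env1")]

# The global pool (self.tools in the trainer) holds only env 0's bound method.
global_tools = {"tool": envs[0].tool}

# Resolving by name for *any* sample dispatches to env 0 -- the documented limitation.
results = [global_tools["tool"]() for _ in range(2)]
```

Here `results` is `["env0", "env0"]` and `envs[1].calls` stays 0, which is exactly what the updated assertions check.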
Thanks for the review @qgallouedec! Here's a summary of the main changes since your feedback. Addressing your comments:
When `environment_factory` is used, per-sample tool filtering builds its lookup from `self.tools`, which only holds environment 0's callables. Rather than adding that complexity now, I chose to document this as a known limitation (comment in the code). Let me know if you agree with this approach or if you'd rather I do things differently. Thanks in advance!
```python
# Normalize tensor fields so they are batched consistently with the non-per-sample path.
for k, values in all_multimodal_fields.items():
    if values and isinstance(values[0], torch.Tensor):
        all_multimodal_fields[k] = torch.stack(values, dim=0)
```
Redundant per-sample tokenization for repeated generations
Low Severity
When per_sample_tools is set, every prompt in the generation batch is tokenized individually via apply_chat_template. Because the dataloader repeats each unique prompt num_generations times (via RepeatSampler), prompts with the same tool list are tokenized redundantly — e.g., with num_generations=8 and batch size 4, this results in 32 individual tokenizer calls instead of potentially far fewer. The batch path avoids this with a single batched apply_chat_template call. For per-sample tools, grouping prompts by identical tool sets and batch-tokenizing each group would avoid the redundant work.
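The grouping suggested here could be sketched as follows: bucket prompt indices by a hashable tool-name key, tokenize each bucket once, then scatter the results back into original order. The helper name `tokenize_grouped` and the injected `batch_tokenize` callable are hypothetical, standing in for the batched `apply_chat_template` path.

```python
from collections import defaultdict

def tokenize_grouped(prompts, per_sample_tools, batch_tokenize):
    """Batch-tokenize prompts that share the same tool list, instead of
    calling the tokenizer once per prompt."""
    groups = defaultdict(list)
    for i, tools in enumerate(per_sample_tools):
        key = tuple(sorted(tools)) if tools is not None else None
        groups[key].append(i)

    out = [None] * len(prompts)
    for key, indices in groups.items():
        tools = list(key) if key is not None else None
        encoded = batch_tokenize([prompts[i] for i in indices], tools)
        for i, enc in zip(indices, encoded):
            out[i] = enc
    return out
```

With `num_generations=8` and 4 unique prompts sharing one tool set, this would make 1 tokenizer call instead of 32.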
Cursor Bugbot has reviewed your changes and found 1 potential issue.
There are 2 total unresolved issues (including 1 from previous review).
```yaml
- local: use_model
  title: Using Trained Models
- local: per_sample_tools
  title: Per-Sample Tool Filtering
```
Wrong toctree indentation breaks How-to guides section
Medium Severity
The new per_sample_tools entry uses 2-space indentation instead of 4-space, placing it at the top-level YAML list rather than inside the sections list of "How-to guides." This causes the title: How-to guides on line 44 to detach from its section and become a duplicate title key on the new standalone item. The "How-to guides" section loses its title, and the documentation navigation breaks.
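Assuming the structure described, the fix is to nest the new entry under the How-to guides `sections` list at the same indentation as its siblings, roughly:

```yaml
- title: How-to guides
  sections:
    - local: use_model
      title: Using Trained Models
    - local: per_sample_tools
      title: Per-Sample Tool Filtering
```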


Add per-sample tool filtering to `GRPOTrainer`

What does this PR do?

Adds per-sample tool filtering to `GRPOTrainer`. When the dataset contains a `"tools"` column, each sample can specify which subset of the global `tools` pool is available, controlling both the prompt schema (which tools the model sees) and execution (which tool calls are allowed to run).
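Under the API described above, usage might look like the following sketch. The tool functions and dataset rows are illustrative, not taken from the PR:

```python
def calculator(expression: str) -> str:
    """Toy calculator tool (illustrative)."""
    return "42"

def translator(text: str, target_lang: str) -> str:
    """Toy translator tool (illustrative)."""
    return text

tools = [calculator, translator]  # global pool passed to GRPOTrainer

# Rows for e.g. datasets.Dataset.from_list; each "tools" entry lists the
# tool __name__s visible to that sample, with None falling back to the
# full pool (per the PR description).
rows = [
    {"prompt": "What is 6 * 7?", "tools": ["calculator"]},
    {"prompt": "Translate 'hello' to French.", "tools": ["translator"]},
    {"prompt": "Free-form task.", "tools": None},
]

# Every listed name must exist in the global pool.
pool_names = {t.__name__ for t in tools}
assert all(set(r["tools"] or []) <= pool_names for r in rows)
```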
Motivation

In agentic GRPO training, different samples often require different tool subsets. For example, a math problem should only expose a `calculator` tool, while a translation task should only expose a `translator`. Without this feature, every sample sees the full tool list.

API

When the dataset has no `tools` column, behavior is identical to the existing API: full backward compatibility.

Implementation details
- `_set_signature_columns_if_needed`: adds `"tools"` so the column passes through the dataloader
- `_tokenize_prompts`: new `per_sample_tools` param; tokenizes each prompt individually with its own tool schema when set
- `_tool_call_loop`: new `per_sample_tools` param; builds per-sample sync/async tool dicts for execution filtering
- `_generate`: threads `per_sample_tools` to tokenization and the tool-call loop
- `_generate_and_score_completions`: detects the `"tools"` key in the batch; resolves tool names via per-sample `_sync_tool_dicts`/`_async_tool_dicts`; `None` entries fall back to the full tool list

Tests
New test file `tests/test_grpo_tools_column.py`:

Validation tests:
- `"tools"` column validation

Integration tests:
- `tools` dataset column (all samples, same tools)
- `None` in tools column falls back to full tool list

Files changed
- `trl/trainer/grpo_trainer.py`: core implementation
- `tests/test_grpo_tools_column.py`: new file (test suite)
- `docs/source/per_sample_tools.md`: new file (documentation)

Before submitting
Note
Medium Risk
Touches GRPO rollout/tokenization and tool-execution paths; incorrect filtering could change prompt schemas or tool dispatch during training despite backward-compat fallbacks and added tests.
Overview
Adds per-sample tool filtering to `GRPOTrainer`: when the training dataset includes a `tools` column (list of tool `__name__` strings per row), the trainer now restricts both the rendered chat tool schema and runtime tool execution to that subset, with `None`/missing values falling back to the full global `tools` pool.

This threads a new `per_sample_tools` flow through prompt tokenization, generation, and the tool-call loop, and ensures `tools` is kept as a signature column so it isn't dropped by column pruning. Documentation is added (`docs/source/per_sample_tools.md` + toctree entry) and a comprehensive test suite is introduced covering validation, backward compatibility, subset behavior, `None` fallback, async tools, and interaction notes with `environment_factory`.

Written by Cursor Bugbot for commit 049b300. This will update automatically on new commits.