Skip to content

Narrow prefix-preserving check to the actual requirement#5458

Merged
qgallouedec merged 14 commits intomainfrom
narrow-prefix-preserving-check
Apr 7, 2026
Merged

Narrow prefix-preserving check to the actual requirement#5458
qgallouedec merged 14 commits intomainfrom
narrow-prefix-preserving-check

Conversation

@qgallouedec
Copy link
Copy Markdown
Member

@qgallouedec qgallouedec commented Apr 5, 2026

Before #5224, the tool call loop re-tokenized the full conversation at each iteration, requiring prefix-preservation across all role transitions. Now that this is fixed, only _get_tool_suffix_ids relies on prefix-preservation, and only for the narrow [user, assistant] → [user, assistant, tool] case. This PR aligns the check with the actual requirement, which in turn reveals that Qwen3.5 doesn't need patching anymore.

This PR:

  • is_chat_template_prefix_preserving now tests the exact property that _get_tool_suffix_ids relies on: appending tool messages after [user, assistant] must not change the rendering of earlier messages. (Previously it tested broader user→assistant and assistant→user transitions.)
  • This revealed that Qwen3.5 was already prefix-preserving for the tool transition, only Qwen3 actually needs the training template patch. Removed the now-dead Qwen3.5 training template code and _patch_qwen3_5_training_template.
  • Updated tests accordingly (removed Qwen3.5 from TestGetTrainingChatTemplate, added tool role to test templates).

Note

Medium Risk
Changes affect how tool-response suffix tokens are derived and when chat templates are considered safe for tool-calling, which can subtly impact multi-turn/tool formatting across tokenizers. Scope is limited and covered by updated tests, but regressions would show up as incorrect tool token slicing or template patching decisions.

Overview
Aligns prefix-preserving validation with the actual tool-calling need. is_chat_template_prefix_preserving now specifically checks that appending a tool message after a [user, assistant(tool_calls)] prefix does not change the rendered prefix, matching what _get_tool_suffix_ids relies on.

Simplifies training-template patching and updates tool suffix extraction. Removes Qwen3.5 training-template patching paths (only Qwen3 is patched now), updates docs/examples accordingly, and adjusts _get_tool_suffix_ids in both AsyncRolloutWorker and GRPOTrainer to use a dummy assistant tool_calls message plus safer EOS trimming when extracting the suffix.

Tests updated for tool-aware templates. Chat-template tests add tool/tool_calls handling in templates and drop Qwen3.5 from TestGetTrainingChatTemplate parametrization.

Reviewed by Cursor Bugbot for commit 87131e5. Bugbot is set up for automated code reviews on this repo. Configure here.

@qgallouedec
Copy link
Copy Markdown
Member Author

@codex review

@HuggingFaceDocBuilderDev
Copy link
Copy Markdown

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@chatgpt-codex-connector
Copy link
Copy Markdown

Codex Review: Didn't find any major issues. Bravo.

ℹ️ About Codex in GitHub

Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".

Copy link
Copy Markdown

@cursor cursor bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Cursor Bugbot has reviewed your changes and found 1 potential issue.

Fix All in Cursor

❌ Bugbot Autofix is OFF. To automatically fix reported issues with cloud agents, enable autofix in the Cursor dashboard.

Reviewed by Cursor Bugbot for commit cd8cbfc. Configure here.

Comment on lines 556 to 578
def _get_tool_suffix_ids(self, tool_messages: list[dict[str, Any]]) -> list[int]:
"""Get token IDs for tool result formatting by using a minimal dummy conversation."""
dummy_tool_calls = [{"type": "function", "function": {"name": "dummy", "arguments": {}}}]
dummy_messages = [
{"role": "user", "content": ""},
{"role": "assistant", "content": ""},
{"role": "user", "content": "dummy"},
{"role": "assistant", "tool_calls": dummy_tool_calls},
]
prefix_ids = self.tokenizer.apply_chat_template(
dummy_messages,
return_dict=False,
tools=self.tools or None, # `or None`: Llama bug: it renders tool boilerplate for tools=[]
add_generation_prompt=False,
tokenize=True,
chat_template=self.chat_template,
return_dict=False,
**self.chat_template_kwargs,
)
full_ids = self.tokenizer.apply_chat_template(
dummy_messages + tool_messages,
return_dict=False,
chat_template=self.chat_template,
add_generation_prompt=True,
tools=self.tools or None, # `or None`: Llama bug: it renders tool boilerplate for tools=[]
tokenize=True,
chat_template=self.chat_template,
return_dict=False,
**self.chat_template_kwargs,
)
Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

just to align GRPO and async GRPO

Comment on lines -580 to +586
# EOS (not EOS + newline).
last_eos_idx = max(i for i, tok_id in enumerate(prefix_ids) if tok_id == self.tokenizer.eos_token_id)
prefix_ids = prefix_ids[: last_eos_idx + 1]
# EOS (not EOS + newline). Templates that don't use EOS as end-of-turn (e.g. Gemma uses
# <turn|>) skip this trimming.
eos_positions = [i for i, tok_id in enumerate(prefix_ids) if tok_id == self.tokenizer.eos_token_id]
if eos_positions:
prefix_ids = prefix_ids[: eos_positions[-1] + 1]
Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Comment on lines -1409 to +1413
dummy_messages = [{"role": "user", "content": "dummy"}, {"role": "assistant", "content": "dummy"}]
dummy_tool_calls = [{"type": "function", "function": {"name": "dummy", "arguments": {}}}]
dummy_messages = [
{"role": "user", "content": "dummy"},
{"role": "assistant", "tool_calls": dummy_tool_calls},
]
Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it's closer to real data to have assistant turn that contains a tool call. It's probably equivalent in most template, but it's cleaner this way imo

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As I commented above:

-             {"role": "assistant", "content": "", "tool_calls": dummy_tool_calls},
+             {"role": "assistant", "tool_calls": dummy_tool_calls},

messages = [
{"role": "user", "content": "Ping the service."},
{"role": "assistant", "tool_calls": tool_calls},
{"role": "assistant", "content": "", "tool_calls": tool_calls},
Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are you sure of this? Some lines above, we have: https://github.com/huggingface/transformers/blob/e1b80de84d3c5da35669b2834ef017eeaf620f93/src/transformers/processing_utils.py#L1789-L1790

                if not isinstance(message.get("content"), list):
                    continue
Suggested change
{"role": "assistant", "content": "", "tool_calls": tool_calls},
{"role": "assistant", "tool_calls": tool_calls},

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

strange, I remember that it failed without the content, let's try again

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ah ok I get it, wrong line:
https://github.com/huggingface/transformers/blob/e1b80de84d3c5da35669b2834ef017eeaf620f93/src/transformers/processing_utils.py#L1807

without the content: "", you'd get

        if tokenize:
            batch_images, batch_videos = [], []
            batch_audios = []
            for conversation in conversations:
                images, videos = [], []
                for message in conversation:
>                   visuals = [content for content in message["content"] if content["type"] in ["image", "video"]]
                                                      ^^^^^^^^^^^^^^^^^^
E                   KeyError: 'content'

../miniconda3/envs/trl/lib/python3.13/site-packages/transformers/processing_utils.py:1807: KeyError
--------------------------------------------------------------------------- Captured stderr call ---------------------------------------------------------------------------
The fast path is not available because one of the required library is not installed. Falling back to torch implementation. To install follow https://github.com/fla-org/flash-linear-attention#installation and https://github.com/Dao-AILab/causal-conv1d
Loading weights: 100%|██████████| 60/60 [00:00<00:00, 3877.51it/s]
Loading weights: 100%|██████████| 27/27 [00:00<00:00, 12929.13it/s]
The tokenizer has new PAD/BOS/EOS tokens that differ from the model config and generation config. The model config and generation config were aligned accordingly, being updated with the tokenizer's values. Updated tokens: {'eos_token_id': 248046, 'pad_token_id': 248044}.
  0%|          | 0/51 [00:00<?, ?it/s]Kwargs passed to `processor.__call__` have to be in `processor_kwargs` dict, not in `**kwargs`
========================================================================= short test summary info ==========================================================================
FAILED tests/test_grpo_trainer.py::TestGRPOTrainer::test_training_with_tools_multimodal_response - KeyError: 'content'

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

After reconsideration, I think it's likely a transformers bug.

huggingface/transformers#45290

I suggest keeping this patch for now so that it's not blocking for users

Copy link
Copy Markdown
Member

@albertvillanova albertvillanova left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks. Some comments and questions below.

messages = [
{"role": "user", "content": "Ping the service."},
{"role": "assistant", "tool_calls": tool_calls},
{"role": "assistant", "content": "", "tool_calls": tool_calls},
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are you sure of this? Some lines above, we have: https://github.com/huggingface/transformers/blob/e1b80de84d3c5da35669b2834ef017eeaf620f93/src/transformers/processing_utils.py#L1789-L1790

                if not isinstance(message.get("content"), list):
                    continue
Suggested change
{"role": "assistant", "content": "", "tool_calls": tool_calls},
{"role": "assistant", "tool_calls": tool_calls},



qwen3_5_training_chat_template_2b_and_below = _patch_qwen3_5_training_template(qwen3_5_chat_template_2b_and_below)
qwen3_5_training_chat_template_4b_and_above = _patch_qwen3_5_training_template(qwen3_5_chat_template_4b_and_above)
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You mention Qwen3.5 doesn't need patching anymore. Does this depend on the transformers version?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

no, it's really about the chat template:

from transformers import AutoTokenizer, AutoProcessor

dummy_tool_calls = [{"type": "function", "function": {"name": "dummy", "arguments": {}}}]

messages1 = [
    {"role": "user", "content": "dummy"},
    {"role": "assistant", "content": "", "tool_calls": dummy_tool_calls},
]
messages2 = messages1 + [
    {"role": "tool", "name": "dummy", "content": "dummy"},
]


model_id = "Qwen/Qwen3-0.6B"
tokenizer = AutoTokenizer.from_pretrained(model_id)
text1 = tokenizer.apply_chat_template(messages1, tokenize=False)
text2 = tokenizer.apply_chat_template(messages2, tokenize=False, add_generation_prompt=True)
print(f"\n{'='*60}")
print(f"{model_id}")
print(f"Prefix-preserving: {text2.startswith(text1)}")
print(repr(text1))
print(repr(text2))

model_id = "Qwen/Qwen3.5-0.8B"
tokenizer = AutoProcessor.from_pretrained(model_id)
text1 = tokenizer.apply_chat_template(messages1, tokenize=False)
text2 = tokenizer.apply_chat_template(messages2, tokenize=False, add_generation_prompt=True)
print(f"\n{'='*60}")
print(f"{model_id}")
print(f"Prefix-preserving: {text2.startswith(text1)}")
print(repr(text1))
print(repr(text2))
============================================================
Qwen/Qwen3-0.6B
Prefix-preserving: False
'<|im_start|>user\ndummy<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n<tool_call>\n{"name": "dummy", "arguments": {}}\n</tool_call><|im_end|>\n'
'<|im_start|>user\ndummy<|im_end|>\n<|im_start|>assistant\n<tool_call>\n{"name": "dummy", "arguments": {}}\n</tool_call><|im_end|>\n<|im_start|>user\n<tool_response>\ndummy\n</tool_response><|im_end|>\n<|im_start|>assistant\n'

============================================================
Qwen/Qwen3.5-0.8B
Prefix-preserving: True
'<|im_start|>user\ndummy<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n<tool_call>\n<function=dummy>\n</function>\n</tool_call><|im_end|>\n'
'<|im_start|>user\ndummy<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n<tool_call>\n<function=dummy>\n</function>\n</tool_call><|im_end|>\n<|im_start|>user\n<tool_response>\ndummy\n</tool_response><|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n'

Comment on lines -1409 to +1413
dummy_messages = [{"role": "user", "content": "dummy"}, {"role": "assistant", "content": "dummy"}]
dummy_tool_calls = [{"type": "function", "function": {"name": "dummy", "arguments": {}}}]
dummy_messages = [
{"role": "user", "content": "dummy"},
{"role": "assistant", "tool_calls": dummy_tool_calls},
]
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As I commented above:

-             {"role": "assistant", "content": "", "tool_calls": dummy_tool_calls},
+             {"role": "assistant", "tool_calls": dummy_tool_calls},

@qgallouedec qgallouedec merged commit f19fda4 into main Apr 7, 2026
16 checks passed
@qgallouedec qgallouedec deleted the narrow-prefix-preserving-check branch April 7, 2026 14:38
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants