diff --git a/tests/test_chat_template_utils.py b/tests/test_chat_template_utils.py
index 9fa45e5f71..2cdd45b85f 100644
--- a/tests/test_chat_template_utils.py
+++ b/tests/test_chat_template_utils.py
@@ -146,13 +146,25 @@ def test_add_response_schema(self, tokenizer_name):
class TestIsChatTemplatePrefixPreserving:
def test_prefix_preserving_template(self):
tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen3MoeForSequenceClassification")
+ # docstyle-ignore
tokenizer.chat_template = textwrap.dedent(r"""
{%- for message in messages %}
{%- if message.role == 'user' %}
{{- '<|im_start|>user\n' + message.content + '<|im_end|>\n' }}
{%- elif message.role == 'assistant' %}
- {{- '<|im_start|>assistant\n' + message.content + '<|im_end|>\n' }}
+ {{- '<|im_start|>assistant\n' + message.content }}
+ {%- if message.tool_calls %}
+ {%- for tool_call in message.tool_calls %}
+ {%- if tool_call.function %}
+ {%- set tool_call = tool_call.function %}
+ {%- endif %}
+                    {{- '<tool_call>' + tool_call.name + '</tool_call>' }}
+ {%- endfor %}
+ {%- endif %}
+ {{- '<|im_end|>\n' }}
+ {%- elif message.role == 'tool' %}
+ {{- '<|im_start|>tool\n' + message.content + '<|im_end|>\n' }}
{%- endif %}
{%- endfor %}
@@ -164,8 +176,9 @@ def test_prefix_preserving_template(self):
def test_non_prefix_preserving_template(self):
tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen3MoeForSequenceClassification")
- # The following template is quite typical of models like Qwen3 and GPT-OSS, where the thinking part is
- # only present for last assistant message, which makes it non-prefix-preserving.
+ # The following template is quite typical of models like Qwen3 and GPT-OSS, where the thinking part (even
+ # empty) is only present for last assistant message, which makes it non-prefix-preserving: appending a tool
+ # message changes the earlier output.
# docstyle-ignore
tokenizer.chat_template = textwrap.dedent(r"""
{%- if messages[0].role == 'system' %}
@@ -202,7 +215,17 @@ def test_non_prefix_preserving_template(self):
{%- else %}
{{- '<|im_start|>' + message.role + '\n' + content }}
{%- endif %}
+ {%- if message.tool_calls %}
+ {%- for tool_call in message.tool_calls %}
+ {%- if tool_call.function %}
+ {%- set tool_call = tool_call.function %}
+ {%- endif %}
+                {{- '<tool_call>' + tool_call.name + '</tool_call>' }}
+ {%- endfor %}
+ {%- endif %}
{{- '<|im_end|>\n' }}
+ {%- elif message.role == "tool" %}
+ {{- '<|im_start|>tool\n' + content + '<|im_end|>\n' }}
{%- endif %}
{%- endfor %}
{%- if add_generation_prompt %}
@@ -218,14 +241,6 @@ def test_non_prefix_preserving_template(self):
"tokenizer_name",
[
pytest.param("trl-internal-testing/tiny-Qwen3MoeForSequenceClassification", id="qwen3"),
- pytest.param(
- "trl-internal-testing/tiny-Qwen3_5ForConditionalGeneration",
- id="qwen35",
- marks=pytest.mark.skipif(
- Version(transformers.__version__) < Version("5.0.0"),
- reason="Qwen3.5 tokenizer requires transformers>=5.0.0",
- ),
- ),
],
)
class TestGetTrainingChatTemplate:
diff --git a/trl/chat_template_utils.py b/trl/chat_template_utils.py
index 86317a5467..4ea93b92ad 100644
--- a/trl/chat_template_utils.py
+++ b/trl/chat_template_utils.py
@@ -630,6 +630,10 @@ def is_chat_template_prefix_preserving(tokenizer: PreTrainedTokenizer) -> bool:
"""
Check whether the chat template preserves prefixes when applied.
+ A prefix-preserving chat template renders earlier messages identically regardless of what messages follow. This
+ property is required by `_get_tool_suffix_ids`, which extracts tool response formatting tokens by comparing
+ tokenizations with and without tool messages appended.
+
Args:
tokenizer (`PreTrainedTokenizer`):
Tokenizer instance to check.
@@ -638,24 +642,22 @@ def is_chat_template_prefix_preserving(tokenizer: PreTrainedTokenizer) -> bool:
`bool`:
`True` if the chat template preserves prefixes, `False` otherwise.
"""
+ # Use the same dummy messages as _get_tool_suffix_ids to test the exact property it relies on.
+ dummy_tool_calls = [{"type": "function", "function": {"name": "dummy", "arguments": {}}}]
messages1 = [
- {"role": "user", "content": "What color is the sky?"},
+ {"role": "user", "content": "dummy"},
+ {"role": "assistant", "content": "", "tool_calls": dummy_tool_calls},
]
messages2 = [
- {"role": "user", "content": "What color is the sky?"},
- {"role": "assistant", "content": "It is blue."},
- ]
- messages3 = [
- {"role": "user", "content": "What color is the sky?"},
- {"role": "assistant", "content": "It is blue."},
- {"role": "user", "content": "And at night?"},
+ {"role": "user", "content": "dummy"},
+ {"role": "assistant", "content": "", "tool_calls": dummy_tool_calls},
+ {"role": "tool", "name": "dummy", "content": "dummy"},
]
- text1 = tokenizer.apply_chat_template(messages1, tokenize=False, add_generation_prompt=True)
- text2 = tokenizer.apply_chat_template(messages2, tokenize=False)
- text3 = tokenizer.apply_chat_template(messages3, tokenize=False)
+ text1 = tokenizer.apply_chat_template(messages1, tokenize=False)
+ text2 = tokenizer.apply_chat_template(messages2, tokenize=False, add_generation_prompt=True)
- return text2.startswith(text1) and text3.startswith(text2)
+ return text2.startswith(text1)
# Modifications:
@@ -749,33 +751,12 @@ def is_chat_template_prefix_preserving(tokenizer: PreTrainedTokenizer) -> bool:
{%- endif %}"""
-# Modifications:
-# - {%- if '</think>' in content %}
-# + {%- if '<think>' in content and '</think>' in content %}
-# Always check for both tags to avoid edge cases where the model generates only one tag, which would otherwise be parsed incorrectly
-# - {{- '<|im_start|>' + message.role + '\n' + content }}
-# + {{- '<|im_start|>' + message.role + '\n<think>\n' + reasoning_content + '\n</think>\n\n' + content }}
-# Always include thinking block during training. It's important to have a prefix-preserving template.
-def _patch_qwen3_5_training_template(template: str) -> str:
- return template.replace(
-        "{%- if '</think>' in content %}",
-        "{%- if '<think>' in content and '</think>' in content %}",
- ).replace(
- "{{- '<|im_start|>' + message.role + '\\n' + content }}",
-        "{{- '<|im_start|>' + message.role + '\\n<think>\\n' + reasoning_content + '\\n</think>\\n\\n' + content }}",
- )
-
-
-qwen3_5_training_chat_template_2b_and_below = _patch_qwen3_5_training_template(qwen3_5_chat_template_2b_and_below)
-qwen3_5_training_chat_template_4b_and_above = _patch_qwen3_5_training_template(qwen3_5_chat_template_4b_and_above)
-
-
def get_training_chat_template(tokenizer: PreTrainedTokenizer) -> str | None:
r"""
Get a prefix-preserving chat template for training, if needed.
- If the tokenizer's template isn't prefix-preserving, returns a training-compatible template (currently Qwen3 and
- Qwen3.5 supported). Otherwise, returns `None`.
+ If the tokenizer's template isn't prefix-preserving, returns a training-compatible template (currently Qwen3
+ supported). Otherwise, returns `None`.
Args:
tokenizer (`PreTrainedTokenizer`):
@@ -793,27 +774,31 @@ def get_training_chat_template(tokenizer: PreTrainedTokenizer) -> str | None:
>>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")
>>> messages1 = [
- ... {"role": "user", "content": "What color is the sky?"},
- ... {"role": "assistant", "content": "It is blue."},
+ ... {"role": "user", "content": "What is 2 * 3?"},
+ ... {
+ ... "role": "assistant",
+ ... "content": "",
+ ... "tool_calls": [{"type": "function", "function": {"name": "multiply", "arguments": {"a": 2, "b": 3}}}],
+ ... },
... ]
- >>> messages2 = [
- ... {"role": "user", "content": "What color is the sky?"},
- ... {"role": "assistant", "content": "It is blue."},
- ... {"role": "user", "content": "And at night?"},
+ >>> messages2 = messages1 + [
+ ... {"role": "tool", "name": "multiply", "content": "6"},
... ]
>>> tokenizer.apply_chat_template(messages1, tokenize=False)
-    '<|im_start|>user\nWhat color is the sky?<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\nIt is blue.<|im_end|>\n'
+    '<|im_start|>user\nWhat is 2 * 3?<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n<tool_call>\n{"name": "multiply", "arguments": {"a": 2, "b": 3}}\n</tool_call><|im_end|>\n'
- >>> tokenizer.apply_chat_template(messages2, tokenize=False)
- '<|im_start|>user\nWhat color is the sky?<|im_end|>\n<|im_start|>assistant\nIt is blue.<|im_end|>\n<|im_start|>user\nAnd at night?<|im_end|>\n'
+ >>> tokenizer.apply_chat_template(messages2, tokenize=False, add_generation_prompt=True)
+    '<|im_start|>user\nWhat is 2 * 3?<|im_end|>\n<|im_start|>assistant\n<tool_call>\n{"name": "multiply", "arguments": {"a": 2, "b": 3}}\n</tool_call><|im_end|>\n<|im_start|>user\n<tool_response>\n6\n</tool_response><|im_end|>\n<|im_start|>assistant\n'
- >>> # ^ think tags missing
+ >>> # ^ think tags missing
>>> chat_template = get_training_chat_template(tokenizer)
>>> tokenizer.apply_chat_template(messages1, tokenize=False, chat_template=chat_template)
-    '<|im_start|>user\nWhat color is the sky?<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\nIt is blue.<|im_end|>\n'
+    '<|im_start|>user\nWhat is 2 * 3?<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n<tool_call>\n{"name": "multiply", "arguments": {"a": 2, "b": 3}}\n</tool_call><|im_end|>\n'
- >>> tokenizer.apply_chat_template(messages2, tokenize=False, chat_template=chat_template)
-    '<|im_start|>user\nWhat color is the sky?<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\nIt is blue.<|im_end|>\n<|im_start|>user\nAnd at night?<|im_end|>\n'
+ >>> tokenizer.apply_chat_template(
+ ... messages2, tokenize=False, add_generation_prompt=True, chat_template=chat_template
+ ... )
+    '<|im_start|>user\nWhat is 2 * 3?<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n<tool_call>\n{"name": "multiply", "arguments": {"a": 2, "b": 3}}\n</tool_call><|im_end|>\n<|im_start|>user\n<tool_response>\n6\n</tool_response><|im_end|>\n<|im_start|>assistant\n'
```
"""
# First check if patching is needed
@@ -822,10 +807,6 @@ def get_training_chat_template(tokenizer: PreTrainedTokenizer) -> str | None:
if tokenizer.chat_template == qwen3_chat_template:
return qwen3_training_chat_template
- if tokenizer.chat_template == qwen3_5_chat_template_2b_and_below:
- return qwen3_5_training_chat_template_2b_and_below
- if tokenizer.chat_template == qwen3_5_chat_template_4b_and_above:
- return qwen3_5_training_chat_template_4b_and_above
else:
raise ValueError(
"The tokenizer's chat template is not prefix-preserving and patching is not supported for this template. "
diff --git a/trl/experimental/async_grpo/async_rollout_worker.py b/trl/experimental/async_grpo/async_rollout_worker.py
index 3d7350d5a7..df32627641 100644
--- a/trl/experimental/async_grpo/async_rollout_worker.py
+++ b/trl/experimental/async_grpo/async_rollout_worker.py
@@ -555,31 +555,41 @@ async def _generate_one(
def _get_tool_suffix_ids(self, tool_messages: list[dict[str, Any]]) -> list[int]:
"""Get token IDs for tool result formatting by using a minimal dummy conversation."""
+ dummy_tool_calls = [{"type": "function", "function": {"name": "dummy", "arguments": {}}}]
dummy_messages = [
- {"role": "user", "content": ""},
- {"role": "assistant", "content": ""},
+ {"role": "user", "content": "dummy"},
+ {
+ "role": "assistant",
+ # "content" is required here because VLM processors crash on tokenize=True without it
+ # (KeyError in processing_utils.py). See huggingface/transformers#45290.
+ "content": "",
+ "tool_calls": dummy_tool_calls,
+ },
]
prefix_ids = self.tokenizer.apply_chat_template(
dummy_messages,
- return_dict=False,
- tools=self.tools or None, # `or None`: Llama bug: it renders tool boilerplate for tools=[]
+ add_generation_prompt=False,
+ tokenize=True,
chat_template=self.chat_template,
+ return_dict=False,
**self.chat_template_kwargs,
)
full_ids = self.tokenizer.apply_chat_template(
dummy_messages + tool_messages,
- return_dict=False,
- chat_template=self.chat_template,
add_generation_prompt=True,
- tools=self.tools or None, # `or None`: Llama bug: it renders tool boilerplate for tools=[]
+ tokenize=True,
+ chat_template=self.chat_template,
+ return_dict=False,
**self.chat_template_kwargs,
)
# Some chat templates (notably Qwen3/Qwen3.5) render "...<|im_end|>\n" after an assistant/tool block.
# When we compute `suffix_ids` by slicing `full_ids`, we must align the slicing boundary to
- # EOS (not EOS + newline).
- last_eos_idx = max(i for i, tok_id in enumerate(prefix_ids) if tok_id == self.tokenizer.eos_token_id)
- prefix_ids = prefix_ids[: last_eos_idx + 1]
+        # EOS (not EOS + newline). Templates that don't use EOS as end-of-turn (e.g. Gemma uses
+        # <end_of_turn>) skip this trimming.
+ eos_positions = [i for i, tok_id in enumerate(prefix_ids) if tok_id == self.tokenizer.eos_token_id]
+ if eos_positions:
+ prefix_ids = prefix_ids[: eos_positions[-1] + 1]
if full_ids[: len(prefix_ids)] != prefix_ids:
raise ValueError("Unexpected tokenization: the EOS-trimmed prefix IDs are not a prefix of the full IDs.")
diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py
index fd276b671f..e949181e20 100644
--- a/trl/trainer/grpo_trainer.py
+++ b/trl/trainer/grpo_trainer.py
@@ -1404,7 +1404,17 @@ def _generate_single_turn(self, prompt_ids, images, multimodal_fields):
def _get_tool_suffix_ids(self, tool_messages):
"""Get token IDs for tool result formatting by using a minimal dummy conversation."""
- dummy_messages = [{"role": "user", "content": "dummy"}, {"role": "assistant", "content": "dummy"}]
+ dummy_tool_calls = [{"type": "function", "function": {"name": "dummy", "arguments": {}}}]
+ dummy_messages = [
+ {"role": "user", "content": "dummy"},
+ {
+ "role": "assistant",
+ # "content" is required here because VLM processors crash on tokenize=True without it
+ # (KeyError in processing_utils.py). See huggingface/transformers#45290.
+ "content": "",
+ "tool_calls": dummy_tool_calls,
+ },
+ ]
if self._is_vlm:
dummy_messages = prepare_multimodal_messages(dummy_messages)
prefix_ids = self.processing_class.apply_chat_template(