From 4b3aa51d70e6fa8dc65aa1d5ce94509e5e535e17 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Sun, 5 Apr 2026 18:44:21 +0000 Subject: [PATCH 1/9] Narrow prefix-preserving check to the actual requirement --- tests/test_chat_template_utils.py | 17 ++++----- trl/chat_template_utils.py | 57 +++++++++---------------------- 2 files changed, 23 insertions(+), 51 deletions(-) diff --git a/tests/test_chat_template_utils.py b/tests/test_chat_template_utils.py index 6459e3d8fef..2f6cecbb262 100644 --- a/tests/test_chat_template_utils.py +++ b/tests/test_chat_template_utils.py @@ -153,6 +153,8 @@ def test_prefix_preserving_template(self): {{- '<|im_start|>user\n' + message.content + '<|im_end|>\n' }} {%- elif message.role == 'assistant' %} {{- '<|im_start|>assistant\n' + message.content + '<|im_end|>\n' }} + {%- elif message.role == 'tool' %} + {{- '<|im_start|>tool\n' + message.content + '<|im_end|>\n' }} {%- endif %} {%- endfor %} @@ -164,8 +166,9 @@ def test_prefix_preserving_template(self): def test_non_prefix_preserving_template(self): tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen3MoeForSequenceClassification") - # The following template is quite typical of models like Qwen3 and GPT-OSS, where the thinking part is - # only present for last assistant message, which makes it non-prefix-preserving. + # The following template is quite typical of models like Qwen3 and GPT-OSS, where the thinking part (even + # empty) is only present for last assistant message, which makes it non-prefix-preserving: appending a tool + # message changes the earlier output. # docstyle-ignore tokenizer.chat_template = textwrap.dedent(r""" {%- if messages[0].role == 'system' %} @@ -203,6 +206,8 @@ def test_non_prefix_preserving_template(self): {{- '<|im_start|>' + message.role + '\n' + content }} {%- endif %} {{- '<|im_end|>\n' }} + {%- elif message.role == "tool" %} + {{- '<|im_start|>tool\n' + content + '<|im_end|>\n' }} {%- endif %} {%- endfor %} {%- if add_generation_prompt %} @@ -218,14 +223,6 @@ def test_non_prefix_preserving_template(self): "tokenizer_name", [ pytest.param("trl-internal-testing/tiny-Qwen3MoeForSequenceClassification", id="qwen3"), - pytest.param( - "trl-internal-testing/tiny-Qwen3_5ForConditionalGeneration", - id="qwen35", - marks=pytest.mark.skipif( - Version(transformers.__version__) < Version("5.0.0"), - reason="Qwen3.5 tokenizer requires transformers>=5.0.0", - ), - ), ], ) class TestGetTrainingChatTemplate: diff --git a/trl/chat_template_utils.py b/trl/chat_template_utils.py index 90be3fa3b57..eeb3f74e794 100644 --- a/trl/chat_template_utils.py +++ b/trl/chat_template_utils.py @@ -630,6 +630,10 @@ def is_chat_template_prefix_preserving(tokenizer: PreTrainedTokenizer) -> bool: """ Check whether the chat template preserves prefixes when applied. + A prefix-preserving chat template renders earlier messages identically regardless of what messages follow. + This property is required by `_get_tool_suffix_ids`, which extracts tool response formatting tokens by + comparing tokenizations with and without tool messages appended. + Args: tokenizer (`PreTrainedTokenizer`): Tokenizer instance to check. @@ -640,22 +644,18 @@ def is_chat_template_prefix_preserving(tokenizer: PreTrainedTokenizer) -> bool: """ messages1 = [ {"role": "user", "content": "What color is the sky?"}, - ] - messages2 = [ - {"role": "user", "content": "What color is the sky?"}, {"role": "assistant", "content": "It is blue."}, ] - messages3 = [ + messages2 = [ {"role": "user", "content": "What color is the sky?"}, {"role": "assistant", "content": "It is blue."}, - {"role": "user", "content": "And at night?"}, + {"role": "tool", "name": "say_hi", "content": "Hi!"}, ] - text1 = tokenizer.apply_chat_template(messages1, tokenize=False, add_generation_prompt=True) - text2 = tokenizer.apply_chat_template(messages2, tokenize=False) - text3 = tokenizer.apply_chat_template(messages3, tokenize=False) + text1 = tokenizer.apply_chat_template(messages1, tokenize=False) + text2 = tokenizer.apply_chat_template(messages2, tokenize=False, add_generation_prompt=True) - return text2.startswith(text1) and text3.startswith(text2) + return text2.startswith(text1) # Modifications: @@ -749,33 +749,12 @@ def is_chat_template_prefix_preserving(tokenizer: PreTrainedTokenizer) -> bool: {%- endif %}""" -# Modifications: -# - {%- if '' in content %} -# + {%- if '' in content and '' in content %} -# Always check for both tags to avoid edge cases where the model generates only one tag, which would otherwise be parsed incorrectly -# - {{- '<|im_start|>' + message.role + '\n' + content }} -# + {{- '<|im_start|>' + message.role + '\n\n' + reasoning_content + '\n\n\n' + content }} -# Always include thinking block during training. It's important to have a prefix-preserving template. -def _patch_qwen3_5_training_template(template: str) -> str: - return template.replace( - "{%- if '' in content %}", - "{%- if '' in content and '' in content %}", - ).replace( - "{{- '<|im_start|>' + message.role + '\\n' + content }}", - "{{- '<|im_start|>' + message.role + '\\n\\n' + reasoning_content + '\\n\\n\\n' + content }}", - ) - - -qwen3_5_training_chat_template_2b_and_below = _patch_qwen3_5_training_template(qwen3_5_chat_template_2b_and_below) -qwen3_5_training_chat_template_4b_and_above = _patch_qwen3_5_training_template(qwen3_5_chat_template_4b_and_above) - - def get_training_chat_template(tokenizer: PreTrainedTokenizer) -> str | None: r""" Get a prefix-preserving chat template for training, if needed. - If the tokenizer's template isn't prefix-preserving, returns a training-compatible template (currently Qwen3 and - Qwen3.5 supported). Otherwise, returns `None`. + If the tokenizer's template isn't prefix-preserving, returns a training-compatible template (currently Qwen3 + supported). Otherwise, returns `None`. Args: tokenizer (`PreTrainedTokenizer`): @@ -799,21 +778,21 @@ def get_training_chat_template(tokenizer: PreTrainedTokenizer) -> str | None: >>> messages2 = [ ... {"role": "user", "content": "What color is the sky?"}, ... {"role": "assistant", "content": "It is blue."}, - ... {"role": "user", "content": "And at night?"}, + ... {"role": "tool", "name": "say_hi", "content": "Hi!"}, ... ] >>> tokenizer.apply_chat_template(messages1, tokenize=False) '<|im_start|>user\nWhat color is the sky?<|im_end|>\n<|im_start|>assistant\n\n\n\n\nIt is blue.<|im_end|>\n' - >>> tokenizer.apply_chat_template(messages2, tokenize=False) - '<|im_start|>user\nWhat color is the sky?<|im_end|>\n<|im_start|>assistant\nIt is blue.<|im_end|>\n<|im_start|>user\nAnd at night?<|im_end|>\n' + >>> tokenizer.apply_chat_template(messages2, tokenize=False, add_generation_prompt=True) + '<|im_start|>user\nWhat color is the sky?<|im_end|>\n<|im_start|>assistant\nIt is blue.<|im_end|>\n<|im_start|>user\n\nHi!\n<|im_end|>\n<|im_start|>assistant\n' >>> # ^ think tags missing >>> chat_template = get_training_chat_template(tokenizer) >>> tokenizer.apply_chat_template(messages1, tokenize=False, chat_template=chat_template) '<|im_start|>user\nWhat color is the sky?<|im_end|>\n<|im_start|>assistant\n\n\n\n\nIt is blue.<|im_end|>\n' - >>> tokenizer.apply_chat_template(messages2, tokenize=False, chat_template=chat_template) - '<|im_start|>user\nWhat color is the sky?<|im_end|>\n<|im_start|>assistant\n\n\n\n\nIt is blue.<|im_end|>\n<|im_start|>user\nAnd at night?<|im_end|>\n' + >>> tokenizer.apply_chat_template(messages2, tokenize=False, add_generation_prompt=True, chat_template=chat_template) + '<|im_start|>user\nWhat color is the sky?<|im_end|>\n<|im_start|>assistant\n\n\n\n\nIt is blue.<|im_end|>\n<|im_start|>user\n\nHi!\n<|im_end|>\n<|im_start|>assistant\n' ``` """ # First check if patching is needed @@ -822,10 +801,6 @@ def get_training_chat_template(tokenizer: PreTrainedTokenizer) -> str | None: if tokenizer.chat_template == qwen3_chat_template: return qwen3_training_chat_template - if tokenizer.chat_template == qwen3_5_chat_template_2b_and_below: - return qwen3_5_training_chat_template_2b_and_below - if tokenizer.chat_template == qwen3_5_chat_template_4b_and_above: - return qwen3_5_training_chat_template_4b_and_above else: raise ValueError( "The tokenizer's chat template is not prefix-preserving and patching is not supported for this template. " From 730070bcf14f58ae60c9646ec379441d4d112c01 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Sun, 5 Apr 2026 19:00:02 +0000 Subject: [PATCH 2/9] Update chat template examples to use multiplication function calls --- trl/chat_template_utils.py | 30 ++++++++++++++---------------- 1 file changed, 14 insertions(+), 16 deletions(-) diff --git a/trl/chat_template_utils.py b/trl/chat_template_utils.py index eeb3f74e794..21bbe598685 100644 --- a/trl/chat_template_utils.py +++ b/trl/chat_template_utils.py @@ -643,13 +643,13 @@ def is_chat_template_prefix_preserving(tokenizer: PreTrainedTokenizer) -> bool: `True` if the chat template preserves prefixes, `False` otherwise. """ messages1 = [ - {"role": "user", "content": "What color is the sky?"}, - {"role": "assistant", "content": "It is blue."}, + {"role": "user", "content": "What is 2 * 3?"}, + {"role": "assistant", "content": "", "tool_calls": [{"type": "function", "function": {"name": "multiply", "arguments": {"a": 2, "b": 3}}}]}, ] messages2 = [ - {"role": "user", "content": "What color is the sky?"}, - {"role": "assistant", "content": "It is blue."}, - {"role": "tool", "name": "say_hi", "content": "Hi!"}, + {"role": "user", "content": "What is 2 * 3?"}, + {"role": "assistant", "content": "", "tool_calls": [{"type": "function", "function": {"name": "multiply", "arguments": {"a": 2, "b": 3}}}]}, + {"role": "tool", "name": "multiply", "content": "6"}, ] text1 = tokenizer.apply_chat_template(messages1, tokenize=False) @@ -772,27 +772,25 @@ def get_training_chat_template(tokenizer: PreTrainedTokenizer) -> str | None: >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B") >>> messages1 = [ - ... {"role": "user", "content": "What color is the sky?"}, - ... {"role": "assistant", "content": "It is blue."}, + ... {"role": "user", "content": "What is 2 * 3?"}, + ... {"role": "assistant", "content": "", "tool_calls": [{"type": "function", "function": {"name": "multiply", "arguments": {"a": 2, "b": 3}}}]}, ... ] - >>> messages2 = [ - ... {"role": "user", "content": "What color is the sky?"}, - ... {"role": "assistant", "content": "It is blue."}, - ... {"role": "tool", "name": "say_hi", "content": "Hi!"}, + >>> messages2 = messages1 + [ + ... {"role": "tool", "name": "multiply", "content": "6"}, ... ] >>> tokenizer.apply_chat_template(messages1, tokenize=False) - '<|im_start|>user\nWhat color is the sky?<|im_end|>\n<|im_start|>assistant\n\n\n\n\nIt is blue.<|im_end|>\n' + '<|im_start|>user\nWhat is 2 * 3?<|im_end|>\n<|im_start|>assistant\n\n\n\n\n\n{"name": "multiply", "arguments": {"a": 2, "b": 3}}\n<|im_end|>\n' >>> tokenizer.apply_chat_template(messages2, tokenize=False, add_generation_prompt=True) - '<|im_start|>user\nWhat color is the sky?<|im_end|>\n<|im_start|>assistant\nIt is blue.<|im_end|>\n<|im_start|>user\n\nHi!\n<|im_end|>\n<|im_start|>assistant\n' + '<|im_start|>user\nWhat is 2 * 3?<|im_end|>\n<|im_start|>assistant\n\n{"name": "multiply", "arguments": {"a": 2, "b": 3}}\n<|im_end|>\n<|im_start|>user\n\n6\n<|im_end|>\n<|im_start|>assistant\n' - >>> # ^ think tags missing + >>> # ^ think tags missing >>> chat_template = get_training_chat_template(tokenizer) >>> tokenizer.apply_chat_template(messages1, tokenize=False, chat_template=chat_template) - '<|im_start|>user\nWhat color is the sky?<|im_end|>\n<|im_start|>assistant\n\n\n\n\nIt is blue.<|im_end|>\n' + '<|im_start|>user\nWhat is 2 * 3?<|im_end|>\n<|im_start|>assistant\n\n\n\n\n\n{"name": "multiply", "arguments": {"a": 2, "b": 3}}\n<|im_end|>\n' >>> tokenizer.apply_chat_template(messages2, tokenize=False, add_generation_prompt=True, chat_template=chat_template) - '<|im_start|>user\nWhat color is the sky?<|im_end|>\n<|im_start|>assistant\n\n\n\n\nIt is blue.<|im_end|>\n<|im_start|>user\n\nHi!\n<|im_end|>\n<|im_start|>assistant\n' + '<|im_start|>user\nWhat is 2 * 3?<|im_end|>\n<|im_start|>assistant\n\n\n\n\n\n{"name": "multiply", "arguments": {"a": 2, "b": 3}}\n<|im_end|>\n<|im_start|>user\n\n6\n<|im_end|>\n<|im_start|>assistant\n' ``` """ # First check if patching is needed From 4622d77234bf46d010b782fcebab1b2ed2d2e32a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Sun, 5 Apr 2026 19:01:42 +0000 Subject: [PATCH 3/9] style --- trl/chat_template_utils.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/trl/chat_template_utils.py b/trl/chat_template_utils.py index 21bbe598685..6b9d3c24e2b 100644 --- a/trl/chat_template_utils.py +++ b/trl/chat_template_utils.py @@ -644,11 +644,19 @@ def is_chat_template_prefix_preserving(tokenizer: PreTrainedTokenizer) -> bool: """ messages1 = [ {"role": "user", "content": "What is 2 * 3?"}, - {"role": "assistant", "content": "", "tool_calls": [{"type": "function", "function": {"name": "multiply", "arguments": {"a": 2, "b": 3}}}]}, + { + "role": "assistant", + "content": "", + "tool_calls": [{"type": "function", "function": {"name": "multiply", "arguments": {"a": 2, "b": 3}}}], + }, ] messages2 = [ {"role": "user", "content": "What is 2 * 3?"}, - {"role": "assistant", "content": "", "tool_calls": [{"type": "function", "function": {"name": "multiply", "arguments": {"a": 2, "b": 3}}}]}, + { + "role": "assistant", + "content": "", + "tool_calls": [{"type": "function", "function": {"name": "multiply", "arguments": {"a": 2, "b": 3}}}], + }, {"role": "tool", "name": "multiply", "content": "6"}, ] From 8a00354e478502117c959f76a199873056947072 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Sun, 5 Apr 2026 19:48:57 +0000 Subject: [PATCH 4/9] tools in dummy --- trl/chat_template_utils.py | 35 +++++++++++++++++------------------ trl/trainer/grpo_trainer.py | 8 +++++++- 2 files changed, 24 insertions(+), 19 deletions(-) diff --git a/trl/chat_template_utils.py b/trl/chat_template_utils.py index 6b9d3c24e2b..f596720bacd 100644 --- a/trl/chat_template_utils.py +++ b/trl/chat_template_utils.py @@ -630,9 +630,9 @@ def is_chat_template_prefix_preserving(tokenizer: PreTrainedTokenizer) -> bool: """ Check whether the chat template preserves prefixes when applied. - A prefix-preserving chat template renders earlier messages identically regardless of what messages follow. - This property is required by `_get_tool_suffix_ids`, which extracts tool response formatting tokens by - comparing tokenizations with and without tool messages appended. + A prefix-preserving chat template renders earlier messages identically regardless of what messages follow. This + property is required by `_get_tool_suffix_ids`, which extracts tool response formatting tokens by comparing + tokenizations with and without tool messages appended. Args: tokenizer (`PreTrainedTokenizer`): @@ -642,22 +642,16 @@ def is_chat_template_prefix_preserving(tokenizer: PreTrainedTokenizer) -> bool: `bool`: `True` if the chat template preserves prefixes, `False` otherwise. """ + # Use the same dummy messages as _get_tool_suffix_ids to test the exact property it relies on. + dummy_tool_calls = [{"type": "function", "function": {"name": "dummy", "arguments": {}}}] messages1 = [ - {"role": "user", "content": "What is 2 * 3?"}, - { - "role": "assistant", - "content": "", - "tool_calls": [{"type": "function", "function": {"name": "multiply", "arguments": {"a": 2, "b": 3}}}], - }, + {"role": "user", "content": "dummy"}, + {"role": "assistant", "tool_calls": dummy_tool_calls}, ] messages2 = [ - {"role": "user", "content": "What is 2 * 3?"}, - { - "role": "assistant", - "content": "", - "tool_calls": [{"type": "function", "function": {"name": "multiply", "arguments": {"a": 2, "b": 3}}}], - }, - {"role": "tool", "name": "multiply", "content": "6"}, + {"role": "user", "content": "dummy"}, + {"role": "assistant", "tool_calls": dummy_tool_calls}, + {"role": "tool", "name": "dummy", "content": "dummy"}, ] text1 = tokenizer.apply_chat_template(messages1, tokenize=False) @@ -781,7 +775,10 @@ def get_training_chat_template(tokenizer: PreTrainedTokenizer) -> str | None: >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B") >>> messages1 = [ ... {"role": "user", "content": "What is 2 * 3?"}, - ... {"role": "assistant", "content": "", "tool_calls": [{"type": "function", "function": {"name": "multiply", "arguments": {"a": 2, "b": 3}}}]}, + ... { + ... "role": "assistant", + ... "tool_calls": [{"type": "function", "function": {"name": "multiply", "arguments": {"a": 2, "b": 3}}}], + ... }, ... ] >>> messages2 = messages1 + [ ... {"role": "tool", "name": "multiply", "content": "6"}, @@ -797,7 +794,9 @@ def get_training_chat_template(tokenizer: PreTrainedTokenizer) -> str | None: >>> tokenizer.apply_chat_template(messages1, tokenize=False, chat_template=chat_template) '<|im_start|>user\nWhat is 2 * 3?<|im_end|>\n<|im_start|>assistant\n\n\n\n\n\n{"name": "multiply", "arguments": {"a": 2, "b": 3}}\n<|im_end|>\n' - >>> tokenizer.apply_chat_template(messages2, tokenize=False, add_generation_prompt=True, chat_template=chat_template) + >>> tokenizer.apply_chat_template( + ... messages2, tokenize=False, add_generation_prompt=True, chat_template=chat_template + ... ) '<|im_start|>user\nWhat is 2 * 3?<|im_end|>\n<|im_start|>assistant\n\n\n\n\n\n{"name": "multiply", "arguments": {"a": 2, "b": 3}}\n<|im_end|>\n<|im_start|>user\n\n6\n<|im_end|>\n<|im_start|>assistant\n' ``` """ diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index 0cd9c53aabe..d36e1237eb3 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -1406,7 +1406,13 @@ def _generate_single_turn(self, prompt_ids, images, multimodal_fields): def _get_tool_suffix_ids(self, tool_messages): """Get token IDs for tool result formatting by using a minimal dummy conversation.""" - dummy_messages = [{"role": "user", "content": "dummy"}, {"role": "assistant", "content": "dummy"}] + dummy_messages = [ + {"role": "user", "content": "dummy"}, + { + "role": "assistant", + "tool_calls": [{"type": "function", "function": {"name": "dummy", "arguments": {}}}], + }, + ] if self._is_vlm: dummy_messages = prepare_multimodal_messages(dummy_messages) prefix_ids = self.processing_class.apply_chat_template( From cd8cbfc881340c18e00fc396ff059695c4291ea9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Sun, 5 Apr 2026 20:10:46 +0000 Subject: [PATCH 5/9] Enhance chat template handling to include tool call formatting in messages + async grpo --- tests/test_chat_template_utils.py | 24 ++++++++++++++++--- .../async_grpo/async_rollout_worker.py | 5 ++-- 2 files changed, 24 insertions(+), 5 deletions(-) diff --git a/tests/test_chat_template_utils.py b/tests/test_chat_template_utils.py index 2f6cecbb262..b65231fb76c 100644 --- a/tests/test_chat_template_utils.py +++ b/tests/test_chat_template_utils.py @@ -148,13 +148,23 @@ def test_prefix_preserving_template(self): tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen3MoeForSequenceClassification") tokenizer.chat_template = textwrap.dedent(r""" {%- for message in messages %} + {%- set content = message.content if message.content is string else '' %} {%- if message.role == 'user' %} - {{- '<|im_start|>user\n' + message.content + '<|im_end|>\n' }} + {{- '<|im_start|>user\n' + content + '<|im_end|>\n' }} {%- elif message.role == 'assistant' %} - {{- '<|im_start|>assistant\n' + message.content + '<|im_end|>\n' }} + {{- '<|im_start|>assistant\n' + content }} + {%- if message.tool_calls %} + {%- for tool_call in message.tool_calls %} + {%- if tool_call.function %} + {%- set tool_call = tool_call.function %} + {%- endif %} + {{- '' + tool_call.name + '' }} + {%- endfor %} + {%- endif %} + {{- '<|im_end|>\n' }} {%- elif message.role == 'tool' %} - {{- '<|im_start|>tool\n' + message.content + '<|im_end|>\n' }} + {{- '<|im_start|>tool\n' + content + '<|im_end|>\n' }} {%- endif %} {%- endfor %} @@ -205,6 +215,14 @@ def test_non_prefix_preserving_template(self): {%- else %} {{- '<|im_start|>' + message.role + '\n' + content }} {%- endif %} + {%- if message.tool_calls %} + {%- for tool_call in message.tool_calls %} + {%- if tool_call.function %} + {%- set tool_call = tool_call.function %} + {%- endif %} + {{- '' + tool_call.name + '' }} + {%- endfor %} + {%- endif %} {{- '<|im_end|>\n' }} {%- elif message.role == "tool" %} {{- '<|im_start|>tool\n' + content + '<|im_end|>\n' }} diff --git a/trl/experimental/async_grpo/async_rollout_worker.py b/trl/experimental/async_grpo/async_rollout_worker.py index 3d7350d5a71..d4b64860692 100644 --- a/trl/experimental/async_grpo/async_rollout_worker.py +++ b/trl/experimental/async_grpo/async_rollout_worker.py @@ -555,9 +555,10 @@ async def _generate_one( def _get_tool_suffix_ids(self, tool_messages: list[dict[str, Any]]) -> list[int]: """Get token IDs for tool result formatting by using a minimal dummy conversation.""" + dummy_tool_calls = [{"type": "function", "function": {"name": "dummy", "arguments": {}}}] dummy_messages = [ - {"role": "user", "content": ""}, - {"role": "assistant", "content": ""}, + {"role": "user", "content": "dummy"}, + {"role": "assistant", "content": "dummy", "tool_calls": dummy_tool_calls}, ] prefix_ids = self.tokenizer.apply_chat_template( dummy_messages, From 8b35321abf87fc658b3d119507e875d1657d2691 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Sun, 5 Apr 2026 20:18:59 +0000 Subject: [PATCH 6/9] align grpo and async --- .../async_grpo/async_rollout_worker.py | 21 +++++++++++-------- trl/trainer/grpo_trainer.py | 6 ++---- 2 files changed, 14 insertions(+), 13 deletions(-) diff --git a/trl/experimental/async_grpo/async_rollout_worker.py b/trl/experimental/async_grpo/async_rollout_worker.py index d4b64860692..2d0e4b2b6c7 100644 --- a/trl/experimental/async_grpo/async_rollout_worker.py +++ b/trl/experimental/async_grpo/async_rollout_worker.py @@ -558,29 +558,32 @@ def _get_tool_suffix_ids(self, tool_messages: list[dict[str, Any]]) -> list[int] dummy_tool_calls = [{"type": "function", "function": {"name": "dummy", "arguments": {}}}] dummy_messages = [ {"role": "user", "content": "dummy"}, - {"role": "assistant", "content": "dummy", "tool_calls": dummy_tool_calls}, + {"role": "assistant", "tool_calls": dummy_tool_calls}, ] prefix_ids = self.tokenizer.apply_chat_template( dummy_messages, - return_dict=False, - tools=self.tools or None, # `or None`: Llama bug: it renders tool boilerplate for tools=[] + add_generation_prompt=False, + tokenize=True, chat_template=self.chat_template, + return_dict=False, **self.chat_template_kwargs, ) full_ids = self.tokenizer.apply_chat_template( dummy_messages + tool_messages, - return_dict=False, - chat_template=self.chat_template, add_generation_prompt=True, - tools=self.tools or None, # `or None`: Llama bug: it renders tool boilerplate for tools=[] + tokenize=True, + chat_template=self.chat_template, + return_dict=False, **self.chat_template_kwargs, ) # Some chat templates (notably Qwen3/Qwen3.5) render "...<|im_end|>\n" after an assistant/tool block. # When we compute `suffix_ids` by slicing `full_ids`, we must align the slicing boundary to - # EOS (not EOS + newline). - last_eos_idx = max(i for i, tok_id in enumerate(prefix_ids) if tok_id == self.tokenizer.eos_token_id) - prefix_ids = prefix_ids[: last_eos_idx + 1] + # EOS (not EOS + newline). Templates that don't use EOS as end-of-turn (e.g. Gemma uses + # ) skip this trimming. + eos_positions = [i for i, tok_id in enumerate(prefix_ids) if tok_id == self.tokenizer.eos_token_id] + if eos_positions: + prefix_ids = prefix_ids[: eos_positions[-1] + 1] if full_ids[: len(prefix_ids)] != prefix_ids: raise ValueError("Unexpected tokenization: the EOS-trimmed prefix IDs are not a prefix of the full IDs.") diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index d36e1237eb3..0d172190777 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -1406,12 +1406,10 @@ def _generate_single_turn(self, prompt_ids, images, multimodal_fields): def _get_tool_suffix_ids(self, tool_messages): """Get token IDs for tool result formatting by using a minimal dummy conversation.""" + dummy_tool_calls = [{"type": "function", "function": {"name": "dummy", "arguments": {}}}] dummy_messages = [ {"role": "user", "content": "dummy"}, - { - "role": "assistant", - "tool_calls": [{"type": "function", "function": {"name": "dummy", "arguments": {}}}], - }, + {"role": "assistant", "tool_calls": dummy_tool_calls}, ] if self._is_vlm: dummy_messages = prepare_multimodal_messages(dummy_messages) From 8dd341b0c6d3854962f8f997ab20d7e34095d721 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Mon, 6 Apr 2026 14:04:54 +0000 Subject: [PATCH 7/9] revert no content --- tests/test_chat_template_utils.py | 9 ++++----- trl/chat_template_utils.py | 5 +++-- trl/experimental/async_grpo/async_rollout_worker.py | 2 +- trl/trainer/grpo_trainer.py | 2 +- 4 files changed, 9 insertions(+), 9 deletions(-) diff --git a/tests/test_chat_template_utils.py b/tests/test_chat_template_utils.py index b65231fb76c..1d1840c623a 100644 --- a/tests/test_chat_template_utils.py +++ b/tests/test_chat_template_utils.py @@ -148,12 +148,11 @@ def test_prefix_preserving_template(self): tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen3MoeForSequenceClassification") tokenizer.chat_template = textwrap.dedent(r""" {%- for message in messages %} - {%- set content = message.content if message.content is string else '' %} {%- if message.role == 'user' %} - {{- '<|im_start|>user\n' + content + '<|im_end|>\n' }} + {{- '<|im_start|>user\n' + message.content + '<|im_end|>\n' }} {%- elif message.role == 'assistant' %} - {{- '<|im_start|>assistant\n' + content }} + {{- '<|im_start|>assistant\n' + message.content }} {%- if message.tool_calls %} {%- for tool_call in message.tool_calls %} {%- if tool_call.function %} @@ -164,7 +163,7 @@ def test_prefix_preserving_template(self): {%- endif %} {{- '<|im_end|>\n' }} {%- elif message.role == 'tool' %} - {{- '<|im_start|>tool\n' + content + '<|im_end|>\n' }} + {{- '<|im_start|>tool\n' + message.content + '<|im_end|>\n' }} {%- endif %} {%- endfor %} @@ -471,7 +470,7 @@ def test_parse_response_tool_call_without_arguments(self, tokenizer_name): tool_calls = [{"type": "function", "function": {"name": "ping", "arguments": {}}}] messages = [ {"role": "user", "content": "Ping the service."}, - {"role": "assistant", "tool_calls": tool_calls}, + {"role": "assistant", "content": "", "tool_calls": tool_calls}, ] prefix = tokenizer.apply_chat_template(messages[:1], add_generation_prompt=True).input_ids text = tokenizer.apply_chat_template(messages).input_ids diff --git a/trl/chat_template_utils.py b/trl/chat_template_utils.py index f596720bacd..4c3d0c8eb06 100644 --- a/trl/chat_template_utils.py +++ b/trl/chat_template_utils.py @@ -646,11 +646,11 @@ def is_chat_template_prefix_preserving(tokenizer: PreTrainedTokenizer) -> bool: dummy_tool_calls = [{"type": "function", "function": {"name": "dummy", "arguments": {}}}] messages1 = [ {"role": "user", "content": "dummy"}, - {"role": "assistant", "tool_calls": dummy_tool_calls}, + {"role": "assistant", "content": "", "tool_calls": dummy_tool_calls}, ] messages2 = [ {"role": "user", "content": "dummy"}, - {"role": "assistant", "tool_calls": dummy_tool_calls}, + {"role": "assistant", "content": "", "tool_calls": dummy_tool_calls}, {"role": "tool", "name": "dummy", "content": "dummy"}, ] @@ -777,6 +777,7 @@ def get_training_chat_template(tokenizer: PreTrainedTokenizer) -> str | None: ... {"role": "user", "content": "What is 2 * 3?"}, ... { ... "role": "assistant", + ... "content": "", ... "tool_calls": [{"type": "function", "function": {"name": "multiply", "arguments": {"a": 2, "b": 3}}}], ... }, ... ] diff --git a/trl/experimental/async_grpo/async_rollout_worker.py b/trl/experimental/async_grpo/async_rollout_worker.py index 2d0e4b2b6c7..1a63a7d8e90 100644 --- a/trl/experimental/async_grpo/async_rollout_worker.py +++ b/trl/experimental/async_grpo/async_rollout_worker.py @@ -558,7 +558,7 @@ def _get_tool_suffix_ids(self, tool_messages: list[dict[str, Any]]) -> list[int] dummy_tool_calls = [{"type": "function", "function": {"name": "dummy", "arguments": {}}}] dummy_messages = [ {"role": "user", "content": "dummy"}, - {"role": "assistant", "tool_calls": dummy_tool_calls}, + {"role": "assistant", "content": "", "tool_calls": dummy_tool_calls}, ] prefix_ids = self.tokenizer.apply_chat_template( dummy_messages, diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index 0d172190777..aa64eec2b5a 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -1409,7 +1409,7 @@ def _get_tool_suffix_ids(self, tool_messages): dummy_tool_calls = [{"type": "function", "function": {"name": "dummy", "arguments": {}}}] dummy_messages = [ {"role": "user", "content": "dummy"}, - {"role": "assistant", "tool_calls": dummy_tool_calls}, + {"role": "assistant", "content": "", "tool_calls": dummy_tool_calls}, ] if self._is_vlm: dummy_messages = prepare_multimodal_messages(dummy_messages) From 7e3ddd77f791334af8963437f72ed1063a4b5e7a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Mon, 6 Apr 2026 14:06:33 +0000 Subject: [PATCH 8/9] docstyle ignore --- tests/test_chat_template_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_chat_template_utils.py b/tests/test_chat_template_utils.py index 1d1840c623a..50eb8af4aad 100644 --- a/tests/test_chat_template_utils.py +++ b/tests/test_chat_template_utils.py @@ -146,6 +146,7 @@ def test_add_response_schema(self, tokenizer_name): class TestIsChatTemplatePrefixPreserving: def test_prefix_preserving_template(self): tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen3MoeForSequenceClassification") + # docstyle-ignore tokenizer.chat_template = textwrap.dedent(r""" {%- for message in messages %} From 87131e503f85bdc1ea65947f6937da2b5635c6fe Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Tue, 7 Apr 2026 14:08:06 +0000 Subject: [PATCH 9/9] handle content="" --- tests/test_chat_template_utils.py | 2 +- trl/experimental/async_grpo/async_rollout_worker.py | 8 +++++++- trl/trainer/grpo_trainer.py | 8 +++++++- 3 files changed, 15 insertions(+), 3 deletions(-) diff --git a/tests/test_chat_template_utils.py b/tests/test_chat_template_utils.py index 155bf77d9b4..2cdd45b85fb 100644 --- a/tests/test_chat_template_utils.py +++ b/tests/test_chat_template_utils.py @@ -490,7 +490,7 @@ def test_parse_response_tool_call_without_arguments(self, tokenizer_name): tool_calls = [{"type": "function", "function": {"name": "ping", "arguments": {}}}] messages = [ {"role": "user", "content": "Ping the service."}, - {"role": "assistant", "content": "", "tool_calls": tool_calls}, + {"role": "assistant", "tool_calls": tool_calls}, ] prefix = tokenizer.apply_chat_template(messages[:1], add_generation_prompt=True).input_ids text = tokenizer.apply_chat_template(messages).input_ids diff --git a/trl/experimental/async_grpo/async_rollout_worker.py b/trl/experimental/async_grpo/async_rollout_worker.py index 1a63a7d8e90..df32627641a 100644 --- a/trl/experimental/async_grpo/async_rollout_worker.py +++ b/trl/experimental/async_grpo/async_rollout_worker.py @@ -558,7 +558,13 @@ def _get_tool_suffix_ids(self, tool_messages: list[dict[str, Any]]) -> list[int] dummy_tool_calls = [{"type": "function", "function": {"name": "dummy", "arguments": {}}}] dummy_messages = [ {"role": "user", "content": "dummy"}, - {"role": "assistant", "content": "", "tool_calls": dummy_tool_calls}, + { + "role": "assistant", + # "content" is required here because VLM processors crash on tokenize=True without it + # (KeyError in processing_utils.py). See huggingface/transformers#45290. + "content": "", + "tool_calls": dummy_tool_calls, + }, ] prefix_ids = self.tokenizer.apply_chat_template( dummy_messages, diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index ca08d2bd4fa..e949181e20d 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -1407,7 +1407,13 @@ def _get_tool_suffix_ids(self, tool_messages): dummy_tool_calls = [{"type": "function", "function": {"name": "dummy", "arguments": {}}}] dummy_messages = [ {"role": "user", "content": "dummy"}, - {"role": "assistant", "content": "", "tool_calls": dummy_tool_calls}, + { + "role": "assistant", + # "content" is required here because VLM processors crash on tokenize=True without it + # (KeyError in processing_utils.py). See huggingface/transformers#45290. + "content": "", + "tool_calls": dummy_tool_calls, + }, ] if self._is_vlm: dummy_messages = prepare_multimodal_messages(dummy_messages)