From 4b3aa51d70e6fa8dc65aa1d5ce94509e5e535e17 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Sun, 5 Apr 2026 18:44:21 +0000 Subject: [PATCH 01/12] Narrow prefix-preserving check to the actual requirement --- tests/test_chat_template_utils.py | 17 ++++----- trl/chat_template_utils.py | 57 +++++++++---------------------- 2 files changed, 23 insertions(+), 51 deletions(-) diff --git a/tests/test_chat_template_utils.py b/tests/test_chat_template_utils.py index 6459e3d8fef..2f6cecbb262 100644 --- a/tests/test_chat_template_utils.py +++ b/tests/test_chat_template_utils.py @@ -153,6 +153,8 @@ def test_prefix_preserving_template(self): {{- '<|im_start|>user\n' + message.content + '<|im_end|>\n' }} {%- elif message.role == 'assistant' %} {{- '<|im_start|>assistant\n' + message.content + '<|im_end|>\n' }} + {%- elif message.role == 'tool' %} + {{- '<|im_start|>tool\n' + message.content + '<|im_end|>\n' }} {%- endif %} {%- endfor %} @@ -164,8 +166,9 @@ def test_prefix_preserving_template(self): def test_non_prefix_preserving_template(self): tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen3MoeForSequenceClassification") - # The following template is quite typical of models like Qwen3 and GPT-OSS, where the thinking part is - # only present for last assistant message, which makes it non-prefix-preserving. + # The following template is quite typical of models like Qwen3 and GPT-OSS, where the thinking part (even + # empty) is only present for last assistant message, which makes it non-prefix-preserving: appending a tool + # message changes the earlier output. 
# docstyle-ignore tokenizer.chat_template = textwrap.dedent(r""" {%- if messages[0].role == 'system' %} @@ -203,6 +206,8 @@ def test_non_prefix_preserving_template(self): {{- '<|im_start|>' + message.role + '\n' + content }} {%- endif %} {{- '<|im_end|>\n' }} + {%- elif message.role == "tool" %} + {{- '<|im_start|>tool\n' + content + '<|im_end|>\n' }} {%- endif %} {%- endfor %} {%- if add_generation_prompt %} @@ -218,14 +223,6 @@ def test_non_prefix_preserving_template(self): "tokenizer_name", [ pytest.param("trl-internal-testing/tiny-Qwen3MoeForSequenceClassification", id="qwen3"), - pytest.param( - "trl-internal-testing/tiny-Qwen3_5ForConditionalGeneration", - id="qwen35", - marks=pytest.mark.skipif( - Version(transformers.__version__) < Version("5.0.0"), - reason="Qwen3.5 tokenizer requires transformers>=5.0.0", - ), - ), ], ) class TestGetTrainingChatTemplate: diff --git a/trl/chat_template_utils.py b/trl/chat_template_utils.py index 90be3fa3b57..eeb3f74e794 100644 --- a/trl/chat_template_utils.py +++ b/trl/chat_template_utils.py @@ -630,6 +630,10 @@ def is_chat_template_prefix_preserving(tokenizer: PreTrainedTokenizer) -> bool: """ Check whether the chat template preserves prefixes when applied. + A prefix-preserving chat template renders earlier messages identically regardless of what messages follow. + This property is required by `_get_tool_suffix_ids`, which extracts tool response formatting tokens by + comparing tokenizations with and without tool messages appended. + Args: tokenizer (`PreTrainedTokenizer`): Tokenizer instance to check. 
@@ -640,22 +644,18 @@ def is_chat_template_prefix_preserving(tokenizer: PreTrainedTokenizer) -> bool: """ messages1 = [ {"role": "user", "content": "What color is the sky?"}, - ] - messages2 = [ - {"role": "user", "content": "What color is the sky?"}, {"role": "assistant", "content": "It is blue."}, ] - messages3 = [ + messages2 = [ {"role": "user", "content": "What color is the sky?"}, {"role": "assistant", "content": "It is blue."}, - {"role": "user", "content": "And at night?"}, + {"role": "tool", "name": "say_hi", "content": "Hi!"}, ] - text1 = tokenizer.apply_chat_template(messages1, tokenize=False, add_generation_prompt=True) - text2 = tokenizer.apply_chat_template(messages2, tokenize=False) - text3 = tokenizer.apply_chat_template(messages3, tokenize=False) + text1 = tokenizer.apply_chat_template(messages1, tokenize=False) + text2 = tokenizer.apply_chat_template(messages2, tokenize=False, add_generation_prompt=True) - return text2.startswith(text1) and text3.startswith(text2) + return text2.startswith(text1) # Modifications: @@ -749,33 +749,12 @@ def is_chat_template_prefix_preserving(tokenizer: PreTrainedTokenizer) -> bool: {%- endif %}""" -# Modifications: -# - {%- if '' in content %} -# + {%- if '' in content and '' in content %} -# Always check for both tags to avoid edge cases where the model generates only one tag, which would otherwise be parsed incorrectly -# - {{- '<|im_start|>' + message.role + '\n' + content }} -# + {{- '<|im_start|>' + message.role + '\n\n' + reasoning_content + '\n\n\n' + content }} -# Always include thinking block during training. It's important to have a prefix-preserving template. 
-def _patch_qwen3_5_training_template(template: str) -> str: - return template.replace( - "{%- if '' in content %}", - "{%- if '' in content and '' in content %}", - ).replace( - "{{- '<|im_start|>' + message.role + '\\n' + content }}", - "{{- '<|im_start|>' + message.role + '\\n\\n' + reasoning_content + '\\n\\n\\n' + content }}", - ) - - -qwen3_5_training_chat_template_2b_and_below = _patch_qwen3_5_training_template(qwen3_5_chat_template_2b_and_below) -qwen3_5_training_chat_template_4b_and_above = _patch_qwen3_5_training_template(qwen3_5_chat_template_4b_and_above) - - def get_training_chat_template(tokenizer: PreTrainedTokenizer) -> str | None: r""" Get a prefix-preserving chat template for training, if needed. - If the tokenizer's template isn't prefix-preserving, returns a training-compatible template (currently Qwen3 and - Qwen3.5 supported). Otherwise, returns `None`. + If the tokenizer's template isn't prefix-preserving, returns a training-compatible template (currently Qwen3 + supported). Otherwise, returns `None`. Args: tokenizer (`PreTrainedTokenizer`): @@ -799,21 +778,21 @@ def get_training_chat_template(tokenizer: PreTrainedTokenizer) -> str | None: >>> messages2 = [ ... {"role": "user", "content": "What color is the sky?"}, ... {"role": "assistant", "content": "It is blue."}, - ... {"role": "user", "content": "And at night?"}, + ... {"role": "tool", "name": "say_hi", "content": "Hi!"}, ... 
] >>> tokenizer.apply_chat_template(messages1, tokenize=False) '<|im_start|>user\nWhat color is the sky?<|im_end|>\n<|im_start|>assistant\n\n\n\n\nIt is blue.<|im_end|>\n' - >>> tokenizer.apply_chat_template(messages2, tokenize=False) - '<|im_start|>user\nWhat color is the sky?<|im_end|>\n<|im_start|>assistant\nIt is blue.<|im_end|>\n<|im_start|>user\nAnd at night?<|im_end|>\n' + >>> tokenizer.apply_chat_template(messages2, tokenize=False, add_generation_prompt=True) + '<|im_start|>user\nWhat color is the sky?<|im_end|>\n<|im_start|>assistant\nIt is blue.<|im_end|>\n<|im_start|>user\n\nHi!\n<|im_end|>\n<|im_start|>assistant\n' >>> # ^ think tags missing >>> chat_template = get_training_chat_template(tokenizer) >>> tokenizer.apply_chat_template(messages1, tokenize=False, chat_template=chat_template) '<|im_start|>user\nWhat color is the sky?<|im_end|>\n<|im_start|>assistant\n\n\n\n\nIt is blue.<|im_end|>\n' - >>> tokenizer.apply_chat_template(messages2, tokenize=False, chat_template=chat_template) - '<|im_start|>user\nWhat color is the sky?<|im_end|>\n<|im_start|>assistant\n\n\n\n\nIt is blue.<|im_end|>\n<|im_start|>user\nAnd at night?<|im_end|>\n' + >>> tokenizer.apply_chat_template(messages2, tokenize=False, add_generation_prompt=True, chat_template=chat_template) + '<|im_start|>user\nWhat color is the sky?<|im_end|>\n<|im_start|>assistant\n\n\n\n\nIt is blue.<|im_end|>\n<|im_start|>user\n\nHi!\n<|im_end|>\n<|im_start|>assistant\n' ``` """ # First check if patching is needed @@ -822,10 +801,6 @@ def get_training_chat_template(tokenizer: PreTrainedTokenizer) -> str | None: if tokenizer.chat_template == qwen3_chat_template: return qwen3_training_chat_template - if tokenizer.chat_template == qwen3_5_chat_template_2b_and_below: - return qwen3_5_training_chat_template_2b_and_below - if tokenizer.chat_template == qwen3_5_chat_template_4b_and_above: - return qwen3_5_training_chat_template_4b_and_above else: raise ValueError( "The tokenizer's chat template is not 
prefix-preserving and patching is not supported for this template. " From 730070bcf14f58ae60c9646ec379441d4d112c01 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Sun, 5 Apr 2026 19:00:02 +0000 Subject: [PATCH 02/12] Update chat template examples to use multiplication function calls --- trl/chat_template_utils.py | 30 ++++++++++++++---------------- 1 file changed, 14 insertions(+), 16 deletions(-) diff --git a/trl/chat_template_utils.py b/trl/chat_template_utils.py index eeb3f74e794..21bbe598685 100644 --- a/trl/chat_template_utils.py +++ b/trl/chat_template_utils.py @@ -643,13 +643,13 @@ def is_chat_template_prefix_preserving(tokenizer: PreTrainedTokenizer) -> bool: `True` if the chat template preserves prefixes, `False` otherwise. """ messages1 = [ - {"role": "user", "content": "What color is the sky?"}, - {"role": "assistant", "content": "It is blue."}, + {"role": "user", "content": "What is 2 * 3?"}, + {"role": "assistant", "content": "", "tool_calls": [{"type": "function", "function": {"name": "multiply", "arguments": {"a": 2, "b": 3}}}]}, ] messages2 = [ - {"role": "user", "content": "What color is the sky?"}, - {"role": "assistant", "content": "It is blue."}, - {"role": "tool", "name": "say_hi", "content": "Hi!"}, + {"role": "user", "content": "What is 2 * 3?"}, + {"role": "assistant", "content": "", "tool_calls": [{"type": "function", "function": {"name": "multiply", "arguments": {"a": 2, "b": 3}}}]}, + {"role": "tool", "name": "multiply", "content": "6"}, ] text1 = tokenizer.apply_chat_template(messages1, tokenize=False) @@ -772,27 +772,25 @@ def get_training_chat_template(tokenizer: PreTrainedTokenizer) -> str | None: >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B") >>> messages1 = [ - ... {"role": "user", "content": "What color is the sky?"}, - ... {"role": "assistant", "content": "It is blue."}, + ... {"role": "user", "content": "What is 2 * 3?"}, + ... 
{"role": "assistant", "content": "", "tool_calls": [{"type": "function", "function": {"name": "multiply", "arguments": {"a": 2, "b": 3}}}]}, ... ] - >>> messages2 = [ - ... {"role": "user", "content": "What color is the sky?"}, - ... {"role": "assistant", "content": "It is blue."}, - ... {"role": "tool", "name": "say_hi", "content": "Hi!"}, + >>> messages2 = messages1 + [ + ... {"role": "tool", "name": "multiply", "content": "6"}, ... ] >>> tokenizer.apply_chat_template(messages1, tokenize=False) - '<|im_start|>user\nWhat color is the sky?<|im_end|>\n<|im_start|>assistant\n\n\n\n\nIt is blue.<|im_end|>\n' + '<|im_start|>user\nWhat is 2 * 3?<|im_end|>\n<|im_start|>assistant\n\n\n\n\n\n{"name": "multiply", "arguments": {"a": 2, "b": 3}}\n<|im_end|>\n' >>> tokenizer.apply_chat_template(messages2, tokenize=False, add_generation_prompt=True) - '<|im_start|>user\nWhat color is the sky?<|im_end|>\n<|im_start|>assistant\nIt is blue.<|im_end|>\n<|im_start|>user\n\nHi!\n<|im_end|>\n<|im_start|>assistant\n' + '<|im_start|>user\nWhat is 2 * 3?<|im_end|>\n<|im_start|>assistant\n\n{"name": "multiply", "arguments": {"a": 2, "b": 3}}\n<|im_end|>\n<|im_start|>user\n\n6\n<|im_end|>\n<|im_start|>assistant\n' - >>> # ^ think tags missing + >>> # ^ think tags missing >>> chat_template = get_training_chat_template(tokenizer) >>> tokenizer.apply_chat_template(messages1, tokenize=False, chat_template=chat_template) - '<|im_start|>user\nWhat color is the sky?<|im_end|>\n<|im_start|>assistant\n\n\n\n\nIt is blue.<|im_end|>\n' + '<|im_start|>user\nWhat is 2 * 3?<|im_end|>\n<|im_start|>assistant\n\n\n\n\n\n{"name": "multiply", "arguments": {"a": 2, "b": 3}}\n<|im_end|>\n' >>> tokenizer.apply_chat_template(messages2, tokenize=False, add_generation_prompt=True, chat_template=chat_template) - '<|im_start|>user\nWhat color is the sky?<|im_end|>\n<|im_start|>assistant\n\n\n\n\nIt is blue.<|im_end|>\n<|im_start|>user\n\nHi!\n<|im_end|>\n<|im_start|>assistant\n' + '<|im_start|>user\nWhat is 2 * 
3?<|im_end|>\n<|im_start|>assistant\n\n\n\n\n\n{"name": "multiply", "arguments": {"a": 2, "b": 3}}\n<|im_end|>\n<|im_start|>user\n\n6\n<|im_end|>\n<|im_start|>assistant\n' ``` """ # First check if patching is needed From 4622d77234bf46d010b782fcebab1b2ed2d2e32a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Sun, 5 Apr 2026 19:01:42 +0000 Subject: [PATCH 03/12] style --- trl/chat_template_utils.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/trl/chat_template_utils.py b/trl/chat_template_utils.py index 21bbe598685..6b9d3c24e2b 100644 --- a/trl/chat_template_utils.py +++ b/trl/chat_template_utils.py @@ -644,11 +644,19 @@ def is_chat_template_prefix_preserving(tokenizer: PreTrainedTokenizer) -> bool: """ messages1 = [ {"role": "user", "content": "What is 2 * 3?"}, - {"role": "assistant", "content": "", "tool_calls": [{"type": "function", "function": {"name": "multiply", "arguments": {"a": 2, "b": 3}}}]}, + { + "role": "assistant", + "content": "", + "tool_calls": [{"type": "function", "function": {"name": "multiply", "arguments": {"a": 2, "b": 3}}}], + }, ] messages2 = [ {"role": "user", "content": "What is 2 * 3?"}, - {"role": "assistant", "content": "", "tool_calls": [{"type": "function", "function": {"name": "multiply", "arguments": {"a": 2, "b": 3}}}]}, + { + "role": "assistant", + "content": "", + "tool_calls": [{"type": "function", "function": {"name": "multiply", "arguments": {"a": 2, "b": 3}}}], + }, {"role": "tool", "name": "multiply", "content": "6"}, ] From 08d4c518ec29b2c9d3e0f2fff5eae90ca9aebdf9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Sun, 5 Apr 2026 19:33:46 +0000 Subject: [PATCH 04/12] Move chat templates from inline strings to `.jinja` files --- trl/chat_template_utils.py | 503 +----------------- trl/chat_templates/README.md | 50 ++ trl/chat_templates/qwen3.jinja | 89 ++++ trl/chat_templates/qwen3_5_2b_and_below.jinja | 154 ++++++ 
trl/chat_templates/qwen3_5_4b_and_above.jinja | 154 ++++++ trl/chat_templates/qwen3_training.jinja | 88 +++ 6 files changed, 546 insertions(+), 492 deletions(-) create mode 100644 trl/chat_templates/README.md create mode 100644 trl/chat_templates/qwen3.jinja create mode 100644 trl/chat_templates/qwen3_5_2b_and_below.jinja create mode 100644 trl/chat_templates/qwen3_5_4b_and_above.jinja create mode 100644 trl/chat_templates/qwen3_training.jinja diff --git a/trl/chat_template_utils.py b/trl/chat_template_utils.py index 6b9d3c24e2b..3e86874816a 100644 --- a/trl/chat_template_utils.py +++ b/trl/chat_template_utils.py @@ -12,8 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. +from pathlib import Path + from transformers import AddedToken, AutoTokenizer, PreTrainedModel, PreTrainedTokenizer +_CHAT_TEMPLATES_DIR = Path(__file__).parent / "chat_templates" + def clone_chat_template( model: PreTrainedModel, @@ -178,409 +182,12 @@ def clone_chat_template( }, } -# docstyle-ignore -qwen3_chat_template = r"""{%- if tools %} - {{- '<|im_start|>system\n' }} - {%- if messages[0].role == 'system' %} - {{- messages[0].content + '\n\n' }} - {%- endif %} - {{- "# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within XML tags:\n" }} - {%- for tool in tools %} - {{- "\n" }} - {{- tool | tojson }} - {%- endfor %} - {{- "\n\n\nFor each function call, return a json object with function name and arguments within XML tags:\n\n{\"name\": , \"arguments\": }\n<|im_end|>\n" }} -{%- else %} - {%- if messages[0].role == 'system' %} - {{- '<|im_start|>system\n' + messages[0].content + '<|im_end|>\n' }} - {%- endif %} -{%- endif %} -{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %} -{%- for message in messages[::-1] %} - {%- set index = (messages|length - 1) - loop.index0 %} - {%- if ns.multi_step_tool and message.role == "user" 
and message.content is string and not(message.content.startswith('') and message.content.endswith('')) %} - {%- set ns.multi_step_tool = false %} - {%- set ns.last_query_index = index %} - {%- endif %} -{%- endfor %} -{%- for message in messages %} - {%- if message.content is string %} - {%- set content = message.content %} - {%- else %} - {%- set content = '' %} - {%- endif %} - {%- if (message.role == "user") or (message.role == "system" and not loop.first) %} - {{- '<|im_start|>' + message.role + '\n' + content + '<|im_end|>' + '\n' }} - {%- elif message.role == "assistant" %} - {%- set reasoning_content = '' %} - {%- if message.reasoning_content is string %} - {%- set reasoning_content = message.reasoning_content %} - {%- else %} - {%- if '' in content %} - {%- set reasoning_content = content.split('')[0].rstrip('\n').split('')[-1].lstrip('\n') %} - {%- set content = content.split('')[-1].lstrip('\n') %} - {%- endif %} - {%- endif %} - {%- if loop.index0 > ns.last_query_index %} - {%- if loop.last or (not loop.last and reasoning_content) %} - {{- '<|im_start|>' + message.role + '\n\n' + reasoning_content.strip('\n') + '\n\n\n' + content.lstrip('\n') }} - {%- else %} - {{- '<|im_start|>' + message.role + '\n' + content }} - {%- endif %} - {%- else %} - {{- '<|im_start|>' + message.role + '\n' + content }} - {%- endif %} - {%- if message.tool_calls %} - {%- for tool_call in message.tool_calls %} - {%- if (loop.first and content) or (not loop.first) %} - {{- '\n' }} - {%- endif %} - {%- if tool_call.function %} - {%- set tool_call = tool_call.function %} - {%- endif %} - {{- '\n{"name": "' }} - {{- tool_call.name }} - {{- '", "arguments": ' }} - {%- if tool_call.arguments is string %} - {{- tool_call.arguments }} - {%- else %} - {{- tool_call.arguments | tojson }} - {%- endif %} - {{- '}\n' }} - {%- endfor %} - {%- endif %} - {{- '<|im_end|>\n' }} - {%- elif message.role == "tool" %} - {%- if loop.first or (messages[loop.index0 - 1].role != "tool") %} - {{- 
'<|im_start|>user' }} - {%- endif %} - {{- '\n\n' }} - {{- content }} - {{- '\n' }} - {%- if loop.last or (messages[loop.index0 + 1].role != "tool") %} - {{- '<|im_end|>\n' }} - {%- endif %} - {%- endif %} -{%- endfor %} -{%- if add_generation_prompt %} - {{- '<|im_start|>assistant\n' }} - {%- if enable_thinking is defined and enable_thinking is false %} - {{- '\n\n\n\n' }} - {%- endif %} -{%- endif %}""" - -# docstyle-ignore -qwen3_5_chat_template_2b_and_below = r"""{%- set image_count = namespace(value=0) %} -{%- set video_count = namespace(value=0) %} -{%- macro render_content(content, do_vision_count, is_system_content=false) %} - {%- if content is string %} - {{- content }} - {%- elif content is iterable and content is not mapping %} - {%- for item in content %} - {%- if 'image' in item or 'image_url' in item or item.type == 'image' %} - {%- if is_system_content %} - {{- raise_exception('System message cannot contain images.') }} - {%- endif %} - {%- if do_vision_count %} - {%- set image_count.value = image_count.value + 1 %} - {%- endif %} - {%- if add_vision_id %} - {{- 'Picture ' ~ image_count.value ~ ': ' }} - {%- endif %} - {{- '<|vision_start|><|image_pad|><|vision_end|>' }} - {%- elif 'video' in item or item.type == 'video' %} - {%- if is_system_content %} - {{- raise_exception('System message cannot contain videos.') }} - {%- endif %} - {%- if do_vision_count %} - {%- set video_count.value = video_count.value + 1 %} - {%- endif %} - {%- if add_vision_id %} - {{- 'Video ' ~ video_count.value ~ ': ' }} - {%- endif %} - {{- '<|vision_start|><|video_pad|><|vision_end|>' }} - {%- elif 'text' in item %} - {{- item.text }} - {%- else %} - {{- raise_exception('Unexpected item type in content.') }} - {%- endif %} - {%- endfor %} - {%- elif content is none or content is undefined %} - {{- '' }} - {%- else %} - {{- raise_exception('Unexpected content type.') }} - {%- endif %} -{%- endmacro %} -{%- if not messages %} - {{- raise_exception('No messages provided.') 
}} -{%- endif %} -{%- if tools and tools is iterable and tools is not mapping %} - {{- '<|im_start|>system\n' }} - {{- "# Tools\n\nYou have access to the following functions:\n\n" }} - {%- for tool in tools %} - {{- "\n" }} - {{- tool | tojson }} - {%- endfor %} - {{- "\n" }} - {{- '\n\nIf you choose to call a function ONLY reply in the following format with NO suffix:\n\n\n\n\nvalue_1\n\n\nThis is the value for the second parameter\nthat can span\nmultiple lines\n\n\n\n\n\nReminder:\n- Function calls MUST follow the specified format: an inner block must be nested within XML tags\n- Required parameters MUST be specified\n- You may provide optional reasoning for your function call in natural language BEFORE the function call, but NOT after\n- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls\n' }} - {%- if messages[0].role == 'system' %} - {%- set content = render_content(messages[0].content, false, true)|trim %} - {%- if content %} - {{- '\n\n' + content }} - {%- endif %} - {%- endif %} - {{- '<|im_end|>\n' }} -{%- else %} - {%- if messages[0].role == 'system' %} - {%- set content = render_content(messages[0].content, false, true)|trim %} - {{- '<|im_start|>system\n' + content + '<|im_end|>\n' }} - {%- endif %} -{%- endif %} -{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %} -{%- for message in messages[::-1] %} - {%- set index = (messages|length - 1) - loop.index0 %} - {%- if ns.multi_step_tool and message.role == "user" %} - {%- set content = render_content(message.content, false)|trim %} - {%- if not(content.startswith('') and content.endswith('')) %} - {%- set ns.multi_step_tool = false %} - {%- set ns.last_query_index = index %} - {%- endif %} - {%- endif %} -{%- endfor %} -{%- if ns.multi_step_tool %} - {{- raise_exception('No user query found in messages.') }} -{%- endif %} -{%- for message in messages %} - {%- set content = 
render_content(message.content, true)|trim %} - {%- if message.role == "system" %} - {%- if not loop.first %} - {{- raise_exception('System message must be at the beginning.') }} - {%- endif %} - {%- elif message.role == "user" %} - {{- '<|im_start|>' + message.role + '\n' + content + '<|im_end|>' + '\n' }} - {%- elif message.role == "assistant" %} - {%- set reasoning_content = '' %} - {%- if message.reasoning_content is string %} - {%- set reasoning_content = message.reasoning_content %} - {%- else %} - {%- if '' in content %} - {%- set reasoning_content = content.split('')[0].rstrip('\n').split('')[-1].lstrip('\n') %} - {%- set content = content.split('')[-1].lstrip('\n') %} - {%- endif %} - {%- endif %} - {%- set reasoning_content = reasoning_content|trim %} - {%- if loop.index0 > ns.last_query_index %} - {{- '<|im_start|>' + message.role + '\n\n' + reasoning_content + '\n\n\n' + content }} - {%- else %} - {{- '<|im_start|>' + message.role + '\n' + content }} - {%- endif %} - {%- if message.tool_calls and message.tool_calls is iterable and message.tool_calls is not mapping %} - {%- for tool_call in message.tool_calls %} - {%- if tool_call.function is defined %} - {%- set tool_call = tool_call.function %} - {%- endif %} - {%- if loop.first %} - {%- if content|trim %} - {{- '\n\n\n\n' }} - {%- else %} - {{- '\n\n' }} - {%- endif %} - {%- else %} - {{- '\n\n\n' }} - {%- endif %} - {%- if tool_call.arguments is defined %} - {%- for args_name, args_value in tool_call.arguments|items %} - {{- '\n' }} - {%- set args_value = args_value | tojson | safe if args_value is mapping or (args_value is sequence and args_value is not string) else args_value | string %} - {{- args_value }} - {{- '\n\n' }} - {%- endfor %} - {%- endif %} - {{- '\n' }} - {%- endfor %} - {%- endif %} - {{- '<|im_end|>\n' }} - {%- elif message.role == "tool" %} - {%- if loop.previtem and loop.previtem.role != "tool" %} - {{- '<|im_start|>user' }} - {%- endif %} - {{- '\n\n' }} - {{- content }} - {{- 
'\n' }} - {%- if not loop.last and loop.nextitem.role != "tool" %} - {{- '<|im_end|>\n' }} - {%- elif loop.last %} - {{- '<|im_end|>\n' }} - {%- endif %} - {%- else %} - {{- raise_exception('Unexpected message role.') }} - {%- endif %} -{%- endfor %} -{%- if add_generation_prompt %} - {{- '<|im_start|>assistant\n' }} - {%- if enable_thinking is defined and enable_thinking is true %} - {{- '\n' }} - {%- else %} - {{- '\n\n\n\n' }} - {%- endif %} -{%- endif %}""" - - -# docstyle-ignore -qwen3_5_chat_template_4b_and_above = r"""{%- set image_count = namespace(value=0) %} -{%- set video_count = namespace(value=0) %} -{%- macro render_content(content, do_vision_count, is_system_content=false) %} - {%- if content is string %} - {{- content }} - {%- elif content is iterable and content is not mapping %} - {%- for item in content %} - {%- if 'image' in item or 'image_url' in item or item.type == 'image' %} - {%- if is_system_content %} - {{- raise_exception('System message cannot contain images.') }} - {%- endif %} - {%- if do_vision_count %} - {%- set image_count.value = image_count.value + 1 %} - {%- endif %} - {%- if add_vision_id %} - {{- 'Picture ' ~ image_count.value ~ ': ' }} - {%- endif %} - {{- '<|vision_start|><|image_pad|><|vision_end|>' }} - {%- elif 'video' in item or item.type == 'video' %} - {%- if is_system_content %} - {{- raise_exception('System message cannot contain videos.') }} - {%- endif %} - {%- if do_vision_count %} - {%- set video_count.value = video_count.value + 1 %} - {%- endif %} - {%- if add_vision_id %} - {{- 'Video ' ~ video_count.value ~ ': ' }} - {%- endif %} - {{- '<|vision_start|><|video_pad|><|vision_end|>' }} - {%- elif 'text' in item %} - {{- item.text }} - {%- else %} - {{- raise_exception('Unexpected item type in content.') }} - {%- endif %} - {%- endfor %} - {%- elif content is none or content is undefined %} - {{- '' }} - {%- else %} - {{- raise_exception('Unexpected content type.') }} - {%- endif %} -{%- endmacro %} -{%- if not 
messages %} - {{- raise_exception('No messages provided.') }} -{%- endif %} -{%- if tools and tools is iterable and tools is not mapping %} - {{- '<|im_start|>system\n' }} - {{- "# Tools\n\nYou have access to the following functions:\n\n" }} - {%- for tool in tools %} - {{- "\n" }} - {{- tool | tojson }} - {%- endfor %} - {{- "\n" }} - {{- '\n\nIf you choose to call a function ONLY reply in the following format with NO suffix:\n\n\n\n\nvalue_1\n\n\nThis is the value for the second parameter\nthat can span\nmultiple lines\n\n\n\n\n\nReminder:\n- Function calls MUST follow the specified format: an inner block must be nested within XML tags\n- Required parameters MUST be specified\n- You may provide optional reasoning for your function call in natural language BEFORE the function call, but NOT after\n- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls\n' }} - {%- if messages[0].role == 'system' %} - {%- set content = render_content(messages[0].content, false, true)|trim %} - {%- if content %} - {{- '\n\n' + content }} - {%- endif %} - {%- endif %} - {{- '<|im_end|>\n' }} -{%- else %} - {%- if messages[0].role == 'system' %} - {%- set content = render_content(messages[0].content, false, true)|trim %} - {{- '<|im_start|>system\n' + content + '<|im_end|>\n' }} - {%- endif %} -{%- endif %} -{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %} -{%- for message in messages[::-1] %} - {%- set index = (messages|length - 1) - loop.index0 %} - {%- if ns.multi_step_tool and message.role == "user" %} - {%- set content = render_content(message.content, false)|trim %} - {%- if not(content.startswith('') and content.endswith('')) %} - {%- set ns.multi_step_tool = false %} - {%- set ns.last_query_index = index %} - {%- endif %} - {%- endif %} -{%- endfor %} -{%- if ns.multi_step_tool %} - {{- raise_exception('No user query found in messages.') }} -{%- endif 
%} -{%- for message in messages %} - {%- set content = render_content(message.content, true)|trim %} - {%- if message.role == "system" %} - {%- if not loop.first %} - {{- raise_exception('System message must be at the beginning.') }} - {%- endif %} - {%- elif message.role == "user" %} - {{- '<|im_start|>' + message.role + '\n' + content + '<|im_end|>' + '\n' }} - {%- elif message.role == "assistant" %} - {%- set reasoning_content = '' %} - {%- if message.reasoning_content is string %} - {%- set reasoning_content = message.reasoning_content %} - {%- else %} - {%- if '' in content %} - {%- set reasoning_content = content.split('')[0].rstrip('\n').split('')[-1].lstrip('\n') %} - {%- set content = content.split('')[-1].lstrip('\n') %} - {%- endif %} - {%- endif %} - {%- set reasoning_content = reasoning_content|trim %} - {%- if loop.index0 > ns.last_query_index %} - {{- '<|im_start|>' + message.role + '\n\n' + reasoning_content + '\n\n\n' + content }} - {%- else %} - {{- '<|im_start|>' + message.role + '\n' + content }} - {%- endif %} - {%- if message.tool_calls and message.tool_calls is iterable and message.tool_calls is not mapping %} - {%- for tool_call in message.tool_calls %} - {%- if tool_call.function is defined %} - {%- set tool_call = tool_call.function %} - {%- endif %} - {%- if loop.first %} - {%- if content|trim %} - {{- '\n\n\n\n' }} - {%- else %} - {{- '\n\n' }} - {%- endif %} - {%- else %} - {{- '\n\n\n' }} - {%- endif %} - {%- if tool_call.arguments is defined %} - {%- for args_name, args_value in tool_call.arguments|items %} - {{- '\n' }} - {%- set args_value = args_value | tojson | safe if args_value is mapping or (args_value is sequence and args_value is not string) else args_value | string %} - {{- args_value }} - {{- '\n\n' }} - {%- endfor %} - {%- endif %} - {{- '\n' }} - {%- endfor %} - {%- endif %} - {{- '<|im_end|>\n' }} - {%- elif message.role == "tool" %} - {%- if loop.previtem and loop.previtem.role != "tool" %} - {{- '<|im_start|>user' }} - 
{%- endif %} - {{- '\n\n' }} - {{- content }} - {{- '\n' }} - {%- if not loop.last and loop.nextitem.role != "tool" %} - {{- '<|im_end|>\n' }} - {%- elif loop.last %} - {{- '<|im_end|>\n' }} - {%- endif %} - {%- else %} - {{- raise_exception('Unexpected message role.') }} - {%- endif %} -{%- endfor %} -{%- if add_generation_prompt %} - {{- '<|im_start|>assistant\n' }} - {%- if enable_thinking is defined and enable_thinking is false %} - {{- '\n\n\n\n' }} - {%- else %} - {{- '\n' }} - {%- endif %} -{%- endif %}""" +qwen3_chat_template = (_CHAT_TEMPLATES_DIR / "qwen3.jinja").read_text() + +qwen3_5_chat_template_2b_and_below = (_CHAT_TEMPLATES_DIR / "qwen3_5_2b_and_below.jinja").read_text() + + +qwen3_5_chat_template_4b_and_above = (_CHAT_TEMPLATES_DIR / "qwen3_5_4b_and_above.jinja").read_text() def add_response_schema(tokenizer: PreTrainedTokenizer) -> PreTrainedTokenizer: @@ -666,95 +273,7 @@ def is_chat_template_prefix_preserving(tokenizer: PreTrainedTokenizer) -> bool: return text2.startswith(text1) -# Modifications: -# - {%- if '' in content %} -# + {%- if '' in content and '' in content %} -# Always check for both tags to avoid edge cases where the model generates only one tag, which would otherwise be parsed incorrectly -# - {%- if loop.index0 > ns.last_query_index %} ... {%- endif %} -# + {{- '<|im_start|>' + message.role + '\n\n' + reasoning_content.strip('\n') + '\n\n\n' + content.lstrip('\n') }} -# Always include thinking block during training. It's important to have a prefix-preserving template. 
-# docstyle-ignore -qwen3_training_chat_template = r"""{%- if tools %} - {{- '<|im_start|>system\n' }} - {%- if messages[0].role == 'system' %} - {{- messages[0].content + '\n\n' }} - {%- endif %} - {{- "# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within XML tags:\n" }} - {%- for tool in tools %} - {{- "\n" }} - {{- tool | tojson }} - {%- endfor %} - {{- "\n\n\nFor each function call, return a json object with function name and arguments within XML tags:\n\n{\"name\": , \"arguments\": }\n<|im_end|>\n" }} -{%- else %} - {%- if messages[0].role == 'system' %} - {{- '<|im_start|>system\n' + messages[0].content + '<|im_end|>\n' }} - {%- endif %} -{%- endif %} -{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %} -{%- for message in messages[::-1] %} - {%- set index = (messages|length - 1) - loop.index0 %} - {%- if ns.multi_step_tool and message.role == "user" and message.content is string and not(message.content.startswith('') and message.content.endswith('')) %} - {%- set ns.multi_step_tool = false %} - {%- set ns.last_query_index = index %} - {%- endif %} -{%- endfor %} -{%- for message in messages %} - {%- if message.content is string %} - {%- set content = message.content %} - {%- else %} - {%- set content = '' %} - {%- endif %} - {%- if (message.role == "user") or (message.role == "system" and not loop.first) %} - {{- '<|im_start|>' + message.role + '\n' + content + '<|im_end|>' + '\n' }} - {%- elif message.role == "assistant" %} - {%- set reasoning_content = '' %} - {%- if message.reasoning_content is string %} - {%- set reasoning_content = message.reasoning_content %} - {%- else %} - {%- if '' in content and '' in content %} - {%- set reasoning_content = content.split('')[0].rstrip('\n').split('')[-1].lstrip('\n') %} - {%- set content = content.split('')[-1].lstrip('\n') %} - {%- endif %} - {%- endif %} - {{- '<|im_start|>' + message.role + '\n\n' + 
reasoning_content.strip('\n') + '\n\n\n' + content.lstrip('\n') }} - {%- if message.tool_calls %} - {%- for tool_call in message.tool_calls %} - {%- if (loop.first and content) or (not loop.first) %} - {{- '\n' }} - {%- endif %} - {%- if tool_call.function %} - {%- set tool_call = tool_call.function %} - {%- endif %} - {{- '\n{"name": "' }} - {{- tool_call.name }} - {{- '", "arguments": ' }} - {%- if tool_call.arguments is string %} - {{- tool_call.arguments }} - {%- else %} - {{- tool_call.arguments | tojson }} - {%- endif %} - {{- '}\n' }} - {%- endfor %} - {%- endif %} - {{- '<|im_end|>\n' }} - {%- elif message.role == "tool" %} - {%- if loop.first or (messages[loop.index0 - 1].role != "tool") %} - {{- '<|im_start|>user' }} - {%- endif %} - {{- '\n\n' }} - {{- content }} - {{- '\n' }} - {%- if loop.last or (messages[loop.index0 + 1].role != "tool") %} - {{- '<|im_end|>\n' }} - {%- endif %} - {%- endif %} -{%- endfor %} -{%- if add_generation_prompt %} - {{- '<|im_start|>assistant\n' }} - {%- if enable_thinking is defined and enable_thinking is false %} - {{- '\n\n\n\n' }} - {%- endif %} -{%- endif %}""" +qwen3_training_chat_template = (_CHAT_TEMPLATES_DIR / "qwen3_training.jinja").read_text() def get_training_chat_template(tokenizer: PreTrainedTokenizer) -> str | None: diff --git a/trl/chat_templates/README.md b/trl/chat_templates/README.md new file mode 100644 index 00000000000..5eb01cc4aca --- /dev/null +++ b/trl/chat_templates/README.md @@ -0,0 +1,50 @@ +# Chat Templates + +Jinja2 chat templates stored here serve two purposes: + +1. **Identity comparison**: detecting which model is being used (by comparing `tokenizer.chat_template` against known templates) to add the appropriate response schema (`add_response_schema`) or swap in a training template (`get_training_chat_template`). +2. **Training patches**: modified templates that fix prefix-preservation issues for the GRPO tool call loop. 
+ +**Why prefix-preserving?** The GRPO tool call loop extracts tool response formatting tokens by comparing tokenizations with and without tool messages appended (`_get_tool_suffix_ids`). This requires the chat template to be *prefix-preserving*: appending messages must not change how earlier messages are rendered. + +## Original templates + +Used for identity comparison only. + +### `qwen3.jinja` + +Original Qwen3 chat template. + +### `qwen3_5_2b_and_below.jinja` / `qwen3_5_4b_and_above.jinja` + +Original Qwen3.5 chat templates. + +## Training templates + +Patched templates that fix prefix-preservation issues. Swapped in at init when tools are enabled. + +### `qwen3_training.jinja` + +Patched Qwen3 template that always includes thinking blocks, making it prefix-preserving. Diff vs `qwen3.jinja`: + +Require both `<think>` and `</think>` to be present before parsing, to avoid incorrect splitting when the model generates only one tag: + +```diff +- {%- if '</think>' in content %} ++ {%- if '<think>' in content and '</think>' in content %} +``` + +Always include the thinking block regardless of message position. 
The original conditionally omits it based on `loop.last`, which changes the assistant rendering when a tool message is appended — breaking prefix-preservation: + +```diff +- {%- if loop.index0 > ns.last_query_index %} +- {%- if loop.last or (not loop.last and reasoning_content) %} +- {{- '<|im_start|>' + message.role + '\n\n' + reasoning_content.strip('\n') + '\n\n\n' + content.lstrip('\n') }} +- {%- else %} +- {{- '<|im_start|>' + message.role + '\n' + content }} +- {%- endif %} +- {%- else %} +- {{- '<|im_start|>' + message.role + '\n' + content }} +- {%- endif %} ++ {{- '<|im_start|>' + message.role + '\n\n' + reasoning_content.strip('\n') + '\n\n\n' + content.lstrip('\n') }} +``` diff --git a/trl/chat_templates/qwen3.jinja b/trl/chat_templates/qwen3.jinja new file mode 100644 index 00000000000..01be9b307da --- /dev/null +++ b/trl/chat_templates/qwen3.jinja @@ -0,0 +1,89 @@ +{%- if tools %} + {{- '<|im_start|>system\n' }} + {%- if messages[0].role == 'system' %} + {{- messages[0].content + '\n\n' }} + {%- endif %} + {{- "# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within XML tags:\n" }} + {%- for tool in tools %} + {{- "\n" }} + {{- tool | tojson }} + {%- endfor %} + {{- "\n\n\nFor each function call, return a json object with function name and arguments within XML tags:\n\n{\"name\": , \"arguments\": }\n<|im_end|>\n" }} +{%- else %} + {%- if messages[0].role == 'system' %} + {{- '<|im_start|>system\n' + messages[0].content + '<|im_end|>\n' }} + {%- endif %} +{%- endif %} +{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %} +{%- for message in messages[::-1] %} + {%- set index = (messages|length - 1) - loop.index0 %} + {%- if ns.multi_step_tool and message.role == "user" and message.content is string and not(message.content.startswith('') and message.content.endswith('')) %} + {%- set ns.multi_step_tool = false %} + {%- set ns.last_query_index = index 
%} + {%- endif %} +{%- endfor %} +{%- for message in messages %} + {%- if message.content is string %} + {%- set content = message.content %} + {%- else %} + {%- set content = '' %} + {%- endif %} + {%- if (message.role == "user") or (message.role == "system" and not loop.first) %} + {{- '<|im_start|>' + message.role + '\n' + content + '<|im_end|>' + '\n' }} + {%- elif message.role == "assistant" %} + {%- set reasoning_content = '' %} + {%- if message.reasoning_content is string %} + {%- set reasoning_content = message.reasoning_content %} + {%- else %} + {%- if '' in content %} + {%- set reasoning_content = content.split('')[0].rstrip('\n').split('')[-1].lstrip('\n') %} + {%- set content = content.split('')[-1].lstrip('\n') %} + {%- endif %} + {%- endif %} + {%- if loop.index0 > ns.last_query_index %} + {%- if loop.last or (not loop.last and reasoning_content) %} + {{- '<|im_start|>' + message.role + '\n\n' + reasoning_content.strip('\n') + '\n\n\n' + content.lstrip('\n') }} + {%- else %} + {{- '<|im_start|>' + message.role + '\n' + content }} + {%- endif %} + {%- else %} + {{- '<|im_start|>' + message.role + '\n' + content }} + {%- endif %} + {%- if message.tool_calls %} + {%- for tool_call in message.tool_calls %} + {%- if (loop.first and content) or (not loop.first) %} + {{- '\n' }} + {%- endif %} + {%- if tool_call.function %} + {%- set tool_call = tool_call.function %} + {%- endif %} + {{- '\n{"name": "' }} + {{- tool_call.name }} + {{- '", "arguments": ' }} + {%- if tool_call.arguments is string %} + {{- tool_call.arguments }} + {%- else %} + {{- tool_call.arguments | tojson }} + {%- endif %} + {{- '}\n' }} + {%- endfor %} + {%- endif %} + {{- '<|im_end|>\n' }} + {%- elif message.role == "tool" %} + {%- if loop.first or (messages[loop.index0 - 1].role != "tool") %} + {{- '<|im_start|>user' }} + {%- endif %} + {{- '\n\n' }} + {{- content }} + {{- '\n' }} + {%- if loop.last or (messages[loop.index0 + 1].role != "tool") %} + {{- '<|im_end|>\n' }} + {%- endif %} 
+ {%- endif %} +{%- endfor %} +{%- if add_generation_prompt %} + {{- '<|im_start|>assistant\n' }} + {%- if enable_thinking is defined and enable_thinking is false %} + {{- '\n\n\n\n' }} + {%- endif %} +{%- endif %} \ No newline at end of file diff --git a/trl/chat_templates/qwen3_5_2b_and_below.jinja b/trl/chat_templates/qwen3_5_2b_and_below.jinja new file mode 100644 index 00000000000..0ef09f214ea --- /dev/null +++ b/trl/chat_templates/qwen3_5_2b_and_below.jinja @@ -0,0 +1,154 @@ +{%- set image_count = namespace(value=0) %} +{%- set video_count = namespace(value=0) %} +{%- macro render_content(content, do_vision_count, is_system_content=false) %} + {%- if content is string %} + {{- content }} + {%- elif content is iterable and content is not mapping %} + {%- for item in content %} + {%- if 'image' in item or 'image_url' in item or item.type == 'image' %} + {%- if is_system_content %} + {{- raise_exception('System message cannot contain images.') }} + {%- endif %} + {%- if do_vision_count %} + {%- set image_count.value = image_count.value + 1 %} + {%- endif %} + {%- if add_vision_id %} + {{- 'Picture ' ~ image_count.value ~ ': ' }} + {%- endif %} + {{- '<|vision_start|><|image_pad|><|vision_end|>' }} + {%- elif 'video' in item or item.type == 'video' %} + {%- if is_system_content %} + {{- raise_exception('System message cannot contain videos.') }} + {%- endif %} + {%- if do_vision_count %} + {%- set video_count.value = video_count.value + 1 %} + {%- endif %} + {%- if add_vision_id %} + {{- 'Video ' ~ video_count.value ~ ': ' }} + {%- endif %} + {{- '<|vision_start|><|video_pad|><|vision_end|>' }} + {%- elif 'text' in item %} + {{- item.text }} + {%- else %} + {{- raise_exception('Unexpected item type in content.') }} + {%- endif %} + {%- endfor %} + {%- elif content is none or content is undefined %} + {{- '' }} + {%- else %} + {{- raise_exception('Unexpected content type.') }} + {%- endif %} +{%- endmacro %} +{%- if not messages %} + {{- raise_exception('No 
messages provided.') }} +{%- endif %} +{%- if tools and tools is iterable and tools is not mapping %} + {{- '<|im_start|>system\n' }} + {{- "# Tools\n\nYou have access to the following functions:\n\n" }} + {%- for tool in tools %} + {{- "\n" }} + {{- tool | tojson }} + {%- endfor %} + {{- "\n" }} + {{- '\n\nIf you choose to call a function ONLY reply in the following format with NO suffix:\n\n\n\n\nvalue_1\n\n\nThis is the value for the second parameter\nthat can span\nmultiple lines\n\n\n\n\n\nReminder:\n- Function calls MUST follow the specified format: an inner block must be nested within XML tags\n- Required parameters MUST be specified\n- You may provide optional reasoning for your function call in natural language BEFORE the function call, but NOT after\n- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls\n' }} + {%- if messages[0].role == 'system' %} + {%- set content = render_content(messages[0].content, false, true)|trim %} + {%- if content %} + {{- '\n\n' + content }} + {%- endif %} + {%- endif %} + {{- '<|im_end|>\n' }} +{%- else %} + {%- if messages[0].role == 'system' %} + {%- set content = render_content(messages[0].content, false, true)|trim %} + {{- '<|im_start|>system\n' + content + '<|im_end|>\n' }} + {%- endif %} +{%- endif %} +{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %} +{%- for message in messages[::-1] %} + {%- set index = (messages|length - 1) - loop.index0 %} + {%- if ns.multi_step_tool and message.role == "user" %} + {%- set content = render_content(message.content, false)|trim %} + {%- if not(content.startswith('') and content.endswith('')) %} + {%- set ns.multi_step_tool = false %} + {%- set ns.last_query_index = index %} + {%- endif %} + {%- endif %} +{%- endfor %} +{%- if ns.multi_step_tool %} + {{- raise_exception('No user query found in messages.') }} +{%- endif %} +{%- for message in messages %} + 
{%- set content = render_content(message.content, true)|trim %} + {%- if message.role == "system" %} + {%- if not loop.first %} + {{- raise_exception('System message must be at the beginning.') }} + {%- endif %} + {%- elif message.role == "user" %} + {{- '<|im_start|>' + message.role + '\n' + content + '<|im_end|>' + '\n' }} + {%- elif message.role == "assistant" %} + {%- set reasoning_content = '' %} + {%- if message.reasoning_content is string %} + {%- set reasoning_content = message.reasoning_content %} + {%- else %} + {%- if '' in content %} + {%- set reasoning_content = content.split('')[0].rstrip('\n').split('')[-1].lstrip('\n') %} + {%- set content = content.split('')[-1].lstrip('\n') %} + {%- endif %} + {%- endif %} + {%- set reasoning_content = reasoning_content|trim %} + {%- if loop.index0 > ns.last_query_index %} + {{- '<|im_start|>' + message.role + '\n\n' + reasoning_content + '\n\n\n' + content }} + {%- else %} + {{- '<|im_start|>' + message.role + '\n' + content }} + {%- endif %} + {%- if message.tool_calls and message.tool_calls is iterable and message.tool_calls is not mapping %} + {%- for tool_call in message.tool_calls %} + {%- if tool_call.function is defined %} + {%- set tool_call = tool_call.function %} + {%- endif %} + {%- if loop.first %} + {%- if content|trim %} + {{- '\n\n\n\n' }} + {%- else %} + {{- '\n\n' }} + {%- endif %} + {%- else %} + {{- '\n\n\n' }} + {%- endif %} + {%- if tool_call.arguments is defined %} + {%- for args_name, args_value in tool_call.arguments|items %} + {{- '\n' }} + {%- set args_value = args_value | tojson | safe if args_value is mapping or (args_value is sequence and args_value is not string) else args_value | string %} + {{- args_value }} + {{- '\n\n' }} + {%- endfor %} + {%- endif %} + {{- '\n' }} + {%- endfor %} + {%- endif %} + {{- '<|im_end|>\n' }} + {%- elif message.role == "tool" %} + {%- if loop.previtem and loop.previtem.role != "tool" %} + {{- '<|im_start|>user' }} + {%- endif %} + {{- '\n\n' }} + {{- 
content }} + {{- '\n' }} + {%- if not loop.last and loop.nextitem.role != "tool" %} + {{- '<|im_end|>\n' }} + {%- elif loop.last %} + {{- '<|im_end|>\n' }} + {%- endif %} + {%- else %} + {{- raise_exception('Unexpected message role.') }} + {%- endif %} +{%- endfor %} +{%- if add_generation_prompt %} + {{- '<|im_start|>assistant\n' }} + {%- if enable_thinking is defined and enable_thinking is true %} + {{- '\n' }} + {%- else %} + {{- '\n\n\n\n' }} + {%- endif %} +{%- endif %} \ No newline at end of file diff --git a/trl/chat_templates/qwen3_5_4b_and_above.jinja b/trl/chat_templates/qwen3_5_4b_and_above.jinja new file mode 100644 index 00000000000..a585dec894e --- /dev/null +++ b/trl/chat_templates/qwen3_5_4b_and_above.jinja @@ -0,0 +1,154 @@ +{%- set image_count = namespace(value=0) %} +{%- set video_count = namespace(value=0) %} +{%- macro render_content(content, do_vision_count, is_system_content=false) %} + {%- if content is string %} + {{- content }} + {%- elif content is iterable and content is not mapping %} + {%- for item in content %} + {%- if 'image' in item or 'image_url' in item or item.type == 'image' %} + {%- if is_system_content %} + {{- raise_exception('System message cannot contain images.') }} + {%- endif %} + {%- if do_vision_count %} + {%- set image_count.value = image_count.value + 1 %} + {%- endif %} + {%- if add_vision_id %} + {{- 'Picture ' ~ image_count.value ~ ': ' }} + {%- endif %} + {{- '<|vision_start|><|image_pad|><|vision_end|>' }} + {%- elif 'video' in item or item.type == 'video' %} + {%- if is_system_content %} + {{- raise_exception('System message cannot contain videos.') }} + {%- endif %} + {%- if do_vision_count %} + {%- set video_count.value = video_count.value + 1 %} + {%- endif %} + {%- if add_vision_id %} + {{- 'Video ' ~ video_count.value ~ ': ' }} + {%- endif %} + {{- '<|vision_start|><|video_pad|><|vision_end|>' }} + {%- elif 'text' in item %} + {{- item.text }} + {%- else %} + {{- raise_exception('Unexpected item type in 
content.') }} + {%- endif %} + {%- endfor %} + {%- elif content is none or content is undefined %} + {{- '' }} + {%- else %} + {{- raise_exception('Unexpected content type.') }} + {%- endif %} +{%- endmacro %} +{%- if not messages %} + {{- raise_exception('No messages provided.') }} +{%- endif %} +{%- if tools and tools is iterable and tools is not mapping %} + {{- '<|im_start|>system\n' }} + {{- "# Tools\n\nYou have access to the following functions:\n\n" }} + {%- for tool in tools %} + {{- "\n" }} + {{- tool | tojson }} + {%- endfor %} + {{- "\n" }} + {{- '\n\nIf you choose to call a function ONLY reply in the following format with NO suffix:\n\n\n\n\nvalue_1\n\n\nThis is the value for the second parameter\nthat can span\nmultiple lines\n\n\n\n\n\nReminder:\n- Function calls MUST follow the specified format: an inner block must be nested within XML tags\n- Required parameters MUST be specified\n- You may provide optional reasoning for your function call in natural language BEFORE the function call, but NOT after\n- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls\n' }} + {%- if messages[0].role == 'system' %} + {%- set content = render_content(messages[0].content, false, true)|trim %} + {%- if content %} + {{- '\n\n' + content }} + {%- endif %} + {%- endif %} + {{- '<|im_end|>\n' }} +{%- else %} + {%- if messages[0].role == 'system' %} + {%- set content = render_content(messages[0].content, false, true)|trim %} + {{- '<|im_start|>system\n' + content + '<|im_end|>\n' }} + {%- endif %} +{%- endif %} +{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %} +{%- for message in messages[::-1] %} + {%- set index = (messages|length - 1) - loop.index0 %} + {%- if ns.multi_step_tool and message.role == "user" %} + {%- set content = render_content(message.content, false)|trim %} + {%- if not(content.startswith('') and content.endswith('')) %} + {%- 
set ns.multi_step_tool = false %} + {%- set ns.last_query_index = index %} + {%- endif %} + {%- endif %} +{%- endfor %} +{%- if ns.multi_step_tool %} + {{- raise_exception('No user query found in messages.') }} +{%- endif %} +{%- for message in messages %} + {%- set content = render_content(message.content, true)|trim %} + {%- if message.role == "system" %} + {%- if not loop.first %} + {{- raise_exception('System message must be at the beginning.') }} + {%- endif %} + {%- elif message.role == "user" %} + {{- '<|im_start|>' + message.role + '\n' + content + '<|im_end|>' + '\n' }} + {%- elif message.role == "assistant" %} + {%- set reasoning_content = '' %} + {%- if message.reasoning_content is string %} + {%- set reasoning_content = message.reasoning_content %} + {%- else %} + {%- if '' in content %} + {%- set reasoning_content = content.split('')[0].rstrip('\n').split('')[-1].lstrip('\n') %} + {%- set content = content.split('')[-1].lstrip('\n') %} + {%- endif %} + {%- endif %} + {%- set reasoning_content = reasoning_content|trim %} + {%- if loop.index0 > ns.last_query_index %} + {{- '<|im_start|>' + message.role + '\n\n' + reasoning_content + '\n\n\n' + content }} + {%- else %} + {{- '<|im_start|>' + message.role + '\n' + content }} + {%- endif %} + {%- if message.tool_calls and message.tool_calls is iterable and message.tool_calls is not mapping %} + {%- for tool_call in message.tool_calls %} + {%- if tool_call.function is defined %} + {%- set tool_call = tool_call.function %} + {%- endif %} + {%- if loop.first %} + {%- if content|trim %} + {{- '\n\n\n\n' }} + {%- else %} + {{- '\n\n' }} + {%- endif %} + {%- else %} + {{- '\n\n\n' }} + {%- endif %} + {%- if tool_call.arguments is defined %} + {%- for args_name, args_value in tool_call.arguments|items %} + {{- '\n' }} + {%- set args_value = args_value | tojson | safe if args_value is mapping or (args_value is sequence and args_value is not string) else args_value | string %} + {{- args_value }} + {{- '\n\n' }} + 
{%- endfor %} + {%- endif %} + {{- '\n' }} + {%- endfor %} + {%- endif %} + {{- '<|im_end|>\n' }} + {%- elif message.role == "tool" %} + {%- if loop.previtem and loop.previtem.role != "tool" %} + {{- '<|im_start|>user' }} + {%- endif %} + {{- '\n\n' }} + {{- content }} + {{- '\n' }} + {%- if not loop.last and loop.nextitem.role != "tool" %} + {{- '<|im_end|>\n' }} + {%- elif loop.last %} + {{- '<|im_end|>\n' }} + {%- endif %} + {%- else %} + {{- raise_exception('Unexpected message role.') }} + {%- endif %} +{%- endfor %} +{%- if add_generation_prompt %} + {{- '<|im_start|>assistant\n' }} + {%- if enable_thinking is defined and enable_thinking is false %} + {{- '\n\n\n\n' }} + {%- else %} + {{- '\n' }} + {%- endif %} +{%- endif %} \ No newline at end of file diff --git a/trl/chat_templates/qwen3_training.jinja b/trl/chat_templates/qwen3_training.jinja new file mode 100644 index 00000000000..978e58e3ebf --- /dev/null +++ b/trl/chat_templates/qwen3_training.jinja @@ -0,0 +1,88 @@ +{#- Training variant of the Qwen3 chat template (see qwen3.jinja for the original). + Modifications vs the original: + - {%- if '' in content %} → {%- if '' in content and '' in content %} + Always check for both tags to avoid edge cases where the model generates only one tag. + - Removed the loop.index0 > ns.last_query_index conditional; always include thinking block. + This makes the template prefix-preserving for the [user, assistant] → [user, assistant, tool] transition. 
+-#} +{%- if tools %} + {{- '<|im_start|>system\n' }} + {%- if messages[0].role == 'system' %} + {{- messages[0].content + '\n\n' }} + {%- endif %} + {{- "# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within XML tags:\n" }} + {%- for tool in tools %} + {{- "\n" }} + {{- tool | tojson }} + {%- endfor %} + {{- "\n\n\nFor each function call, return a json object with function name and arguments within XML tags:\n\n{\"name\": , \"arguments\": }\n<|im_end|>\n" }} +{%- else %} + {%- if messages[0].role == 'system' %} + {{- '<|im_start|>system\n' + messages[0].content + '<|im_end|>\n' }} + {%- endif %} +{%- endif %} +{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %} +{%- for message in messages[::-1] %} + {%- set index = (messages|length - 1) - loop.index0 %} + {%- if ns.multi_step_tool and message.role == "user" and message.content is string and not(message.content.startswith('') and message.content.endswith('')) %} + {%- set ns.multi_step_tool = false %} + {%- set ns.last_query_index = index %} + {%- endif %} +{%- endfor %} +{%- for message in messages %} + {%- if message.content is string %} + {%- set content = message.content %} + {%- else %} + {%- set content = '' %} + {%- endif %} + {%- if (message.role == "user") or (message.role == "system" and not loop.first) %} + {{- '<|im_start|>' + message.role + '\n' + content + '<|im_end|>' + '\n' }} + {%- elif message.role == "assistant" %} + {%- set reasoning_content = '' %} + {%- if message.reasoning_content is string %} + {%- set reasoning_content = message.reasoning_content %} + {%- else %} + {%- if '' in content and '' in content %} + {%- set reasoning_content = content.split('')[0].rstrip('\n').split('')[-1].lstrip('\n') %} + {%- set content = content.split('')[-1].lstrip('\n') %} + {%- endif %} + {%- endif %} + {{- '<|im_start|>' + message.role + '\n\n' + reasoning_content.strip('\n') + '\n\n\n' + 
content.lstrip('\n') }} + {%- if message.tool_calls %} + {%- for tool_call in message.tool_calls %} + {%- if (loop.first and content) or (not loop.first) %} + {{- '\n' }} + {%- endif %} + {%- if tool_call.function %} + {%- set tool_call = tool_call.function %} + {%- endif %} + {{- '\n{"name": "' }} + {{- tool_call.name }} + {{- '", "arguments": ' }} + {%- if tool_call.arguments is string %} + {{- tool_call.arguments }} + {%- else %} + {{- tool_call.arguments | tojson }} + {%- endif %} + {{- '}\n' }} + {%- endfor %} + {%- endif %} + {{- '<|im_end|>\n' }} + {%- elif message.role == "tool" %} + {%- if loop.first or (messages[loop.index0 - 1].role != "tool") %} + {{- '<|im_start|>user' }} + {%- endif %} + {{- '\n\n' }} + {{- content }} + {{- '\n' }} + {%- if loop.last or (messages[loop.index0 + 1].role != "tool") %} + {{- '<|im_end|>\n' }} + {%- endif %} + {%- endif %} +{%- endfor %} +{%- if add_generation_prompt %} + {{- '<|im_start|>assistant\n' }} + {%- if enable_thinking is defined and enable_thinking is false %} + {{- '\n\n\n\n' }} + {%- endif %} +{%- endif %} \ No newline at end of file From 276559da2718c411c4df7f8ca397c7dfcfc6503b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Sun, 5 Apr 2026 19:48:57 +0000 Subject: [PATCH 05/12] tools in dummy --- trl/chat_template_utils.py | 36 ++++++++++++++++++------------------ trl/trainer/grpo_trainer.py | 8 +++++++- 2 files changed, 25 insertions(+), 19 deletions(-) diff --git a/trl/chat_template_utils.py b/trl/chat_template_utils.py index 3e86874816a..615722cac8c 100644 --- a/trl/chat_template_utils.py +++ b/trl/chat_template_utils.py @@ -16,6 +16,7 @@ from transformers import AddedToken, AutoTokenizer, PreTrainedModel, PreTrainedTokenizer + _CHAT_TEMPLATES_DIR = Path(__file__).parent / "chat_templates" @@ -237,9 +238,9 @@ def is_chat_template_prefix_preserving(tokenizer: PreTrainedTokenizer) -> bool: """ Check whether the chat template preserves prefixes when applied. 
- A prefix-preserving chat template renders earlier messages identically regardless of what messages follow. - This property is required by `_get_tool_suffix_ids`, which extracts tool response formatting tokens by - comparing tokenizations with and without tool messages appended. + A prefix-preserving chat template renders earlier messages identically regardless of what messages follow. This + property is required by `_get_tool_suffix_ids`, which extracts tool response formatting tokens by comparing + tokenizations with and without tool messages appended. Args: tokenizer (`PreTrainedTokenizer`): @@ -249,22 +250,16 @@ def is_chat_template_prefix_preserving(tokenizer: PreTrainedTokenizer) -> bool: `bool`: `True` if the chat template preserves prefixes, `False` otherwise. """ + # Use the same dummy messages as _get_tool_suffix_ids to test the exact property it relies on. + dummy_tool_calls = [{"type": "function", "function": {"name": "dummy", "arguments": {}}}] messages1 = [ - {"role": "user", "content": "What is 2 * 3?"}, - { - "role": "assistant", - "content": "", - "tool_calls": [{"type": "function", "function": {"name": "multiply", "arguments": {"a": 2, "b": 3}}}], - }, + {"role": "user", "content": "dummy"}, + {"role": "assistant", "tool_calls": dummy_tool_calls}, ] messages2 = [ - {"role": "user", "content": "What is 2 * 3?"}, - { - "role": "assistant", - "content": "", - "tool_calls": [{"type": "function", "function": {"name": "multiply", "arguments": {"a": 2, "b": 3}}}], - }, - {"role": "tool", "name": "multiply", "content": "6"}, + {"role": "user", "content": "dummy"}, + {"role": "assistant", "tool_calls": dummy_tool_calls}, + {"role": "tool", "name": "dummy", "content": "dummy"}, ] text1 = tokenizer.apply_chat_template(messages1, tokenize=False) @@ -300,7 +295,10 @@ def get_training_chat_template(tokenizer: PreTrainedTokenizer) -> str | None: >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B") >>> messages1 = [ ... 
{"role": "user", "content": "What is 2 * 3?"}, - ... {"role": "assistant", "content": "", "tool_calls": [{"type": "function", "function": {"name": "multiply", "arguments": {"a": 2, "b": 3}}}]}, + ... { + ... "role": "assistant", + ... "tool_calls": [{"type": "function", "function": {"name": "multiply", "arguments": {"a": 2, "b": 3}}}], + ... }, ... ] >>> messages2 = messages1 + [ ... {"role": "tool", "name": "multiply", "content": "6"}, @@ -316,7 +314,9 @@ def get_training_chat_template(tokenizer: PreTrainedTokenizer) -> str | None: >>> tokenizer.apply_chat_template(messages1, tokenize=False, chat_template=chat_template) '<|im_start|>user\nWhat is 2 * 3?<|im_end|>\n<|im_start|>assistant\n\n\n\n\n\n{"name": "multiply", "arguments": {"a": 2, "b": 3}}\n<|im_end|>\n' - >>> tokenizer.apply_chat_template(messages2, tokenize=False, add_generation_prompt=True, chat_template=chat_template) + >>> tokenizer.apply_chat_template( + ... messages2, tokenize=False, add_generation_prompt=True, chat_template=chat_template + ... 
) '<|im_start|>user\nWhat is 2 * 3?<|im_end|>\n<|im_start|>assistant\n\n\n\n\n\n{"name": "multiply", "arguments": {"a": 2, "b": 3}}\n<|im_end|>\n<|im_start|>user\n\n6\n<|im_end|>\n<|im_start|>assistant\n' ``` """ diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index 0cd9c53aabe..d36e1237eb3 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -1406,7 +1406,13 @@ def _generate_single_turn(self, prompt_ids, images, multimodal_fields): def _get_tool_suffix_ids(self, tool_messages): """Get token IDs for tool result formatting by using a minimal dummy conversation.""" - dummy_messages = [{"role": "user", "content": "dummy"}, {"role": "assistant", "content": "dummy"}] + dummy_messages = [ + {"role": "user", "content": "dummy"}, + { + "role": "assistant", + "tool_calls": [{"type": "function", "function": {"name": "dummy", "arguments": {}}}], + }, + ] if self._is_vlm: dummy_messages = prepare_multimodal_messages(dummy_messages) prefix_ids = self.processing_class.apply_chat_template( From 673c35d21fab7d533cdf493d5513b243f548e1c4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Sun, 5 Apr 2026 19:56:55 +0000 Subject: [PATCH 06/12] Add chat template files to MANIFEST.in --- MANIFEST.in | 2 ++ 1 file changed, 2 insertions(+) diff --git a/MANIFEST.in b/MANIFEST.in index 16cf649ce60..e843c590270 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -3,6 +3,8 @@ include CONTRIBUTING.md include README.md include trl/accelerate_configs/*.yaml include trl/templates/*.md +include trl/chat_templates/*.jinja +include trl/chat_templates/*.md include trl/skills/**/*.md recursive-exclude * __pycache__ prune tests From 604c476258df074c66729e72ec591ea9db702c4d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Sun, 5 Apr 2026 20:10:46 +0000 Subject: [PATCH 07/12] Enhance chat template handling to include tool call formatting in messages + async grpo --- tests/test_chat_template_utils.py | 24 
++++++++++++++++--- .../async_grpo/async_rollout_worker.py | 5 ++-- 2 files changed, 24 insertions(+), 5 deletions(-) diff --git a/tests/test_chat_template_utils.py b/tests/test_chat_template_utils.py index 2f6cecbb262..b65231fb76c 100644 --- a/tests/test_chat_template_utils.py +++ b/tests/test_chat_template_utils.py @@ -148,13 +148,23 @@ def test_prefix_preserving_template(self): tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen3MoeForSequenceClassification") tokenizer.chat_template = textwrap.dedent(r""" {%- for message in messages %} + {%- set content = message.content if message.content is string else '' %} {%- if message.role == 'user' %} - {{- '<|im_start|>user\n' + message.content + '<|im_end|>\n' }} + {{- '<|im_start|>user\n' + content + '<|im_end|>\n' }} {%- elif message.role == 'assistant' %} - {{- '<|im_start|>assistant\n' + message.content + '<|im_end|>\n' }} + {{- '<|im_start|>assistant\n' + content }} + {%- if message.tool_calls %} + {%- for tool_call in message.tool_calls %} + {%- if tool_call.function %} + {%- set tool_call = tool_call.function %} + {%- endif %} + {{- '' + tool_call.name + '' }} + {%- endfor %} + {%- endif %} + {{- '<|im_end|>\n' }} {%- elif message.role == 'tool' %} - {{- '<|im_start|>tool\n' + message.content + '<|im_end|>\n' }} + {{- '<|im_start|>tool\n' + content + '<|im_end|>\n' }} {%- endif %} {%- endfor %} @@ -205,6 +215,14 @@ def test_non_prefix_preserving_template(self): {%- else %} {{- '<|im_start|>' + message.role + '\n' + content }} {%- endif %} + {%- if message.tool_calls %} + {%- for tool_call in message.tool_calls %} + {%- if tool_call.function %} + {%- set tool_call = tool_call.function %} + {%- endif %} + {{- '' + tool_call.name + '' }} + {%- endfor %} + {%- endif %} {{- '<|im_end|>\n' }} {%- elif message.role == "tool" %} {{- '<|im_start|>tool\n' + content + '<|im_end|>\n' }} diff --git a/trl/experimental/async_grpo/async_rollout_worker.py 
b/trl/experimental/async_grpo/async_rollout_worker.py index 3d7350d5a71..d4b64860692 100644 --- a/trl/experimental/async_grpo/async_rollout_worker.py +++ b/trl/experimental/async_grpo/async_rollout_worker.py @@ -555,9 +555,10 @@ async def _generate_one( def _get_tool_suffix_ids(self, tool_messages: list[dict[str, Any]]) -> list[int]: """Get token IDs for tool result formatting by using a minimal dummy conversation.""" + dummy_tool_calls = [{"type": "function", "function": {"name": "dummy", "arguments": {}}}] dummy_messages = [ - {"role": "user", "content": ""}, - {"role": "assistant", "content": ""}, + {"role": "user", "content": "dummy"}, + {"role": "assistant", "content": "dummy", "tool_calls": dummy_tool_calls}, ] prefix_ids = self.tokenizer.apply_chat_template( dummy_messages, From 83a7ef66d9d563d3b57c6cf56b8b25d4f17f6d9c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Sun, 5 Apr 2026 20:18:59 +0000 Subject: [PATCH 08/12] align grpo and async --- .../async_grpo/async_rollout_worker.py | 21 +++++++++++-------- trl/trainer/grpo_trainer.py | 6 ++---- 2 files changed, 14 insertions(+), 13 deletions(-) diff --git a/trl/experimental/async_grpo/async_rollout_worker.py b/trl/experimental/async_grpo/async_rollout_worker.py index d4b64860692..2d0e4b2b6c7 100644 --- a/trl/experimental/async_grpo/async_rollout_worker.py +++ b/trl/experimental/async_grpo/async_rollout_worker.py @@ -558,29 +558,32 @@ def _get_tool_suffix_ids(self, tool_messages: list[dict[str, Any]]) -> list[int] dummy_tool_calls = [{"type": "function", "function": {"name": "dummy", "arguments": {}}}] dummy_messages = [ {"role": "user", "content": "dummy"}, - {"role": "assistant", "content": "dummy", "tool_calls": dummy_tool_calls}, + {"role": "assistant", "tool_calls": dummy_tool_calls}, ] prefix_ids = self.tokenizer.apply_chat_template( dummy_messages, - return_dict=False, - tools=self.tools or None, # `or None`: Llama bug: it renders tool boilerplate for tools=[] + 
add_generation_prompt=False, + tokenize=True, chat_template=self.chat_template, + return_dict=False, **self.chat_template_kwargs, ) full_ids = self.tokenizer.apply_chat_template( dummy_messages + tool_messages, - return_dict=False, - chat_template=self.chat_template, add_generation_prompt=True, - tools=self.tools or None, # `or None`: Llama bug: it renders tool boilerplate for tools=[] + tokenize=True, + chat_template=self.chat_template, + return_dict=False, **self.chat_template_kwargs, ) # Some chat templates (notably Qwen3/Qwen3.5) render "...<|im_end|>\n" after an assistant/tool block. # When we compute `suffix_ids` by slicing `full_ids`, we must align the slicing boundary to - # EOS (not EOS + newline). - last_eos_idx = max(i for i, tok_id in enumerate(prefix_ids) if tok_id == self.tokenizer.eos_token_id) - prefix_ids = prefix_ids[: last_eos_idx + 1] + # EOS (not EOS + newline). Templates that don't use EOS as end-of-turn (e.g. Gemma uses + # <turn|>) skip this trimming. + eos_positions = [i for i, tok_id in enumerate(prefix_ids) if tok_id == self.tokenizer.eos_token_id] + if eos_positions: + prefix_ids = prefix_ids[: eos_positions[-1] + 1] if full_ids[: len(prefix_ids)] != prefix_ids: raise ValueError("Unexpected tokenization: the EOS-trimmed prefix IDs are not a prefix of the full IDs.") diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index d36e1237eb3..0d172190777 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -1406,12 +1406,10 @@ def _generate_single_turn(self, prompt_ids, images, multimodal_fields): def _get_tool_suffix_ids(self, tool_messages): """Get token IDs for tool result formatting by using a minimal dummy conversation.""" + dummy_tool_calls = [{"type": "function", "function": {"name": "dummy", "arguments": {}}}] dummy_messages = [ {"role": "user", "content": "dummy"}, - { - "role": "assistant", - "tool_calls": [{"type": "function", "function": {"name": "dummy", "arguments": {}}}], - }, + {"role": 
"assistant", "tool_calls": dummy_tool_calls}, ] if self._is_vlm: dummy_messages = prepare_multimodal_messages(dummy_messages) From e5d7cdf70a0925dc120ab62c0c8078c28a5f70df Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Mon, 6 Apr 2026 14:04:54 +0000 Subject: [PATCH 09/12] revert no content --- tests/test_chat_template_utils.py | 9 ++++----- trl/chat_template_utils.py | 5 +++-- trl/experimental/async_grpo/async_rollout_worker.py | 2 +- trl/trainer/grpo_trainer.py | 2 +- 4 files changed, 9 insertions(+), 9 deletions(-) diff --git a/tests/test_chat_template_utils.py b/tests/test_chat_template_utils.py index b65231fb76c..1d1840c623a 100644 --- a/tests/test_chat_template_utils.py +++ b/tests/test_chat_template_utils.py @@ -148,12 +148,11 @@ def test_prefix_preserving_template(self): tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen3MoeForSequenceClassification") tokenizer.chat_template = textwrap.dedent(r""" {%- for message in messages %} - {%- set content = message.content if message.content is string else '' %} {%- if message.role == 'user' %} - {{- '<|im_start|>user\n' + content + '<|im_end|>\n' }} + {{- '<|im_start|>user\n' + message.content + '<|im_end|>\n' }} {%- elif message.role == 'assistant' %} - {{- '<|im_start|>assistant\n' + content }} + {{- '<|im_start|>assistant\n' + message.content }} {%- if message.tool_calls %} {%- for tool_call in message.tool_calls %} {%- if tool_call.function %} @@ -164,7 +163,7 @@ def test_prefix_preserving_template(self): {%- endif %} {{- '<|im_end|>\n' }} {%- elif message.role == 'tool' %} - {{- '<|im_start|>tool\n' + content + '<|im_end|>\n' }} + {{- '<|im_start|>tool\n' + message.content + '<|im_end|>\n' }} {%- endif %} {%- endfor %} @@ -471,7 +470,7 @@ def test_parse_response_tool_call_without_arguments(self, tokenizer_name): tool_calls = [{"type": "function", "function": {"name": "ping", "arguments": {}}}] messages = [ {"role": "user", "content": "Ping the service."}, - 
{"role": "assistant", "tool_calls": tool_calls}, + {"role": "assistant", "content": "", "tool_calls": tool_calls}, ] prefix = tokenizer.apply_chat_template(messages[:1], add_generation_prompt=True).input_ids text = tokenizer.apply_chat_template(messages).input_ids diff --git a/trl/chat_template_utils.py b/trl/chat_template_utils.py index 615722cac8c..fae77c65f15 100644 --- a/trl/chat_template_utils.py +++ b/trl/chat_template_utils.py @@ -254,11 +254,11 @@ def is_chat_template_prefix_preserving(tokenizer: PreTrainedTokenizer) -> bool: dummy_tool_calls = [{"type": "function", "function": {"name": "dummy", "arguments": {}}}] messages1 = [ {"role": "user", "content": "dummy"}, - {"role": "assistant", "tool_calls": dummy_tool_calls}, + {"role": "assistant", "content": "", "tool_calls": dummy_tool_calls}, ] messages2 = [ {"role": "user", "content": "dummy"}, - {"role": "assistant", "tool_calls": dummy_tool_calls}, + {"role": "assistant", "content": "", "tool_calls": dummy_tool_calls}, {"role": "tool", "name": "dummy", "content": "dummy"}, ] @@ -297,6 +297,7 @@ def get_training_chat_template(tokenizer: PreTrainedTokenizer) -> str | None: ... {"role": "user", "content": "What is 2 * 3?"}, ... { ... "role": "assistant", + ... "content": "", ... "tool_calls": [{"type": "function", "function": {"name": "multiply", "arguments": {"a": 2, "b": 3}}}], ... }, ... 
] diff --git a/trl/experimental/async_grpo/async_rollout_worker.py b/trl/experimental/async_grpo/async_rollout_worker.py index 2d0e4b2b6c7..1a63a7d8e90 100644 --- a/trl/experimental/async_grpo/async_rollout_worker.py +++ b/trl/experimental/async_grpo/async_rollout_worker.py @@ -558,7 +558,7 @@ def _get_tool_suffix_ids(self, tool_messages: list[dict[str, Any]]) -> list[int] dummy_tool_calls = [{"type": "function", "function": {"name": "dummy", "arguments": {}}}] dummy_messages = [ {"role": "user", "content": "dummy"}, - {"role": "assistant", "tool_calls": dummy_tool_calls}, + {"role": "assistant", "content": "", "tool_calls": dummy_tool_calls}, ] prefix_ids = self.tokenizer.apply_chat_template( dummy_messages, diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index 0d172190777..aa64eec2b5a 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -1409,7 +1409,7 @@ def _get_tool_suffix_ids(self, tool_messages): dummy_tool_calls = [{"type": "function", "function": {"name": "dummy", "arguments": {}}}] dummy_messages = [ {"role": "user", "content": "dummy"}, - {"role": "assistant", "tool_calls": dummy_tool_calls}, + {"role": "assistant", "content": "", "tool_calls": dummy_tool_calls}, ] if self._is_vlm: dummy_messages = prepare_multimodal_messages(dummy_messages) From a618809dbe7460f5ee5092545f91c7e29d641df1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Mon, 6 Apr 2026 14:06:33 +0000 Subject: [PATCH 10/12] docstyle ignore --- tests/test_chat_template_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_chat_template_utils.py b/tests/test_chat_template_utils.py index 1d1840c623a..50eb8af4aad 100644 --- a/tests/test_chat_template_utils.py +++ b/tests/test_chat_template_utils.py @@ -146,6 +146,7 @@ def test_add_response_schema(self, tokenizer_name): class TestIsChatTemplatePrefixPreserving: def test_prefix_preserving_template(self): tokenizer = 
AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen3MoeForSequenceClassification") + # docstyle-ignore tokenizer.chat_template = textwrap.dedent(r""" {%- for message in messages %} From 7b7f5d15eca0b0866c1afee75df66bc637cfd8fe Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Tue, 7 Apr 2026 14:55:40 +0000 Subject: [PATCH 11/12] revert old modif --- tests/test_chat_template_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_chat_template_utils.py b/tests/test_chat_template_utils.py index 155bf77d9b4..2cdd45b85fb 100644 --- a/tests/test_chat_template_utils.py +++ b/tests/test_chat_template_utils.py @@ -490,7 +490,7 @@ def test_parse_response_tool_call_without_arguments(self, tokenizer_name): tool_calls = [{"type": "function", "function": {"name": "ping", "arguments": {}}}] messages = [ {"role": "user", "content": "Ping the service."}, - {"role": "assistant", "content": "", "tool_calls": tool_calls}, + {"role": "assistant", "tool_calls": tool_calls}, ] prefix = tokenizer.apply_chat_template(messages[:1], add_generation_prompt=True).input_ids text = tokenizer.apply_chat_template(messages).input_ids From 8e315961e13e9dd68211c33a9ba0f6aae5edb764 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Tue, 7 Apr 2026 15:51:34 +0000 Subject: [PATCH 12/12] Add Qwen3-VL tool calling support --- docs/source/grpo_trainer.md | 1 + tests/test_chat_template_utils.py | 21 ++++-- trl/chat_template_utils.py | 5 +- trl/chat_templates/README.md | 4 + trl/chat_templates/qwen3_vl.jinja | 120 ++++++++++++++++++++++++++++++ 5 files changed, 141 insertions(+), 10 deletions(-) create mode 100644 trl/chat_templates/qwen3_vl.jinja diff --git a/docs/source/grpo_trainer.md b/docs/source/grpo_trainer.md index 5176c3da744..cddf96b5cd0 100644 --- a/docs/source/grpo_trainer.md +++ b/docs/source/grpo_trainer.md @@ -742,6 +742,7 @@ Tested with: - [**Gemma4**](https://huggingface.co/collections/google/gemma-4) — 
e.g., `google/gemma-4-E2B-it` - [**Qwen3**](https://huggingface.co/collections/Qwen/qwen3) — e.g., `Qwen/Qwen3-0.6B` +- [**Qwen3-VL**](https://huggingface.co/collections/Qwen/qwen3-vl) — e.g., `Qwen/Qwen3-VL-2B-Instruct` - [**Qwen3.5**](https://huggingface.co/collections/Qwen/qwen35) — e.g., `Qwen/Qwen3.5-2B` > [!TIP] diff --git a/tests/test_chat_template_utils.py b/tests/test_chat_template_utils.py index 2cdd45b85fb..a019f882a70 100644 --- a/tests/test_chat_template_utils.py +++ b/tests/test_chat_template_utils.py @@ -113,7 +113,8 @@ def test_clone_with_sequence_classification_model(self): @pytest.mark.parametrize( "tokenizer_name", [ - pytest.param("trl-internal-testing/tiny-Qwen3MoeForSequenceClassification", id="qwen3"), + pytest.param("trl-internal-testing/tiny-Qwen3MoeForCausalLM", id="qwen3"), + pytest.param("trl-internal-testing/tiny-Qwen3VLForConditionalGeneration", id="qwen3_vl"), pytest.param("trl-internal-testing/tiny-Qwen3_5ForConditionalGeneration", id="qwen35"), ], ) @@ -145,7 +146,7 @@ def test_add_response_schema(self, tokenizer_name): class TestIsChatTemplatePrefixPreserving: def test_prefix_preserving_template(self): - tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen3MoeForSequenceClassification") + tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen3MoeForCausalLM") # docstyle-ignore tokenizer.chat_template = textwrap.dedent(r""" {%- for message in messages %} @@ -175,7 +176,7 @@ def test_prefix_preserving_template(self): assert is_chat_template_prefix_preserving(tokenizer) is True def test_non_prefix_preserving_template(self): - tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen3MoeForSequenceClassification") + tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen3MoeForCausalLM") # The following template is quite typical of models like Qwen3 and GPT-OSS, where the thinking part (even # empty) is only present for last assistant message, which makes it 
non-prefix-preserving: appending a tool # message changes the earlier output. @@ -240,7 +241,7 @@ def test_non_prefix_preserving_template(self): @pytest.mark.parametrize( "tokenizer_name", [ - pytest.param("trl-internal-testing/tiny-Qwen3MoeForSequenceClassification", id="qwen3"), + pytest.param("trl-internal-testing/tiny-Qwen3MoeForCausalLM", id="qwen3"), ], ) class TestGetTrainingChatTemplate: @@ -396,7 +397,8 @@ def test_behavior_unchanged_generation_prompt_with_enable_thinking_false(self, t @pytest.mark.parametrize( "tokenizer_name", [ - pytest.param("trl-internal-testing/tiny-Qwen3MoeForSequenceClassification", id="qwen3"), + pytest.param("trl-internal-testing/tiny-Qwen3MoeForCausalLM", id="qwen3"), + pytest.param("trl-internal-testing/tiny-Qwen3VLForConditionalGeneration", id="qwen3_vl"), pytest.param("trl-internal-testing/tiny-Qwen3_5ForConditionalGeneration", id="qwen35"), pytest.param( "trl-internal-testing/tiny-Gemma4ForConditionalGeneration", @@ -430,8 +432,11 @@ def test_parse_response(self, tokenizer_name): assert parsed == messages[-1] def test_parse_response_with_reasoning_content(self, tokenizer_name): - if tokenizer_name == "trl-internal-testing/tiny-Gemma4ForConditionalGeneration": - pytest.skip("Gemma4 doesn't support inline reasoning_content.") + if tokenizer_name in [ + "trl-internal-testing/tiny-Gemma4ForConditionalGeneration", + "trl-internal-testing/tiny-Qwen3VLForConditionalGeneration", + ]: + pytest.skip("This tokenizer doesn't support inline reasoning_content.") tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) if getattr(tokenizer, "response_schema", None) is None: tokenizer = add_response_schema(tokenizer) @@ -517,7 +522,7 @@ def test_parse_response_multiple_tool_calls(self, tokenizer_name): assert parsed == messages[-1] def test_parse_response_malformed_tool_call(self, tokenizer_name): - if tokenizer_name != "trl-internal-testing/tiny-Qwen3MoeForSequenceClassification": + if tokenizer_name != 
"trl-internal-testing/tiny-Qwen3MoeForCausalLM": pytest.skip("For simplicity, we only test the malformed tool call case on one tokenizer.") tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) if getattr(tokenizer, "response_schema", None) is None: diff --git a/trl/chat_template_utils.py b/trl/chat_template_utils.py index c85a7b45a04..b6897cee232 100644 --- a/trl/chat_template_utils.py +++ b/trl/chat_template_utils.py @@ -185,8 +185,9 @@ def clone_chat_template( qwen3_chat_template = (_CHAT_TEMPLATES_DIR / "qwen3.jinja").read_text() -qwen3_5_chat_template_2b_and_below = (_CHAT_TEMPLATES_DIR / "qwen3_5_2b_and_below.jinja").read_text() +qwen3_vl_chat_template = (_CHAT_TEMPLATES_DIR / "qwen3_vl.jinja").read_text() +qwen3_5_chat_template_2b_and_below = (_CHAT_TEMPLATES_DIR / "qwen3_5_2b_and_below.jinja").read_text() qwen3_5_chat_template_4b_and_above = (_CHAT_TEMPLATES_DIR / "qwen3_5_4b_and_above.jinja").read_text() @@ -220,7 +221,7 @@ def add_response_schema(tokenizer: PreTrainedTokenizer) -> PreTrainedTokenizer: {'role': 'assistant', 'content': '', 'tool_calls': [{'type': 'function', 'function': {'name': 'multiply', 'arguments': {'a': 3, 'b': 4}}}]} ``` """ - if tokenizer.chat_template == qwen3_chat_template: + if tokenizer.chat_template in [qwen3_chat_template, qwen3_vl_chat_template]: tokenizer.response_schema = qwen3_schema return tokenizer if tokenizer.chat_template in [qwen3_5_chat_template_2b_and_below, qwen3_5_chat_template_4b_and_above]: diff --git a/trl/chat_templates/README.md b/trl/chat_templates/README.md index 5eb01cc4aca..899d8022fe6 100644 --- a/trl/chat_templates/README.md +++ b/trl/chat_templates/README.md @@ -15,6 +15,10 @@ Used for identity comparison only. Original Qwen3 chat template. +### `qwen3_vl.jinja` + +Original Qwen3-VL chat template. Unlike text-only Qwen3, this template is already prefix-preserving (no conditional thinking blocks), so no training patch is needed. 
+ ### `qwen3_5_2b_and_below.jinja` / `qwen3_5_4b_and_above.jinja` Original Qwen3.5 chat templates. diff --git a/trl/chat_templates/qwen3_vl.jinja b/trl/chat_templates/qwen3_vl.jinja new file mode 100644 index 00000000000..124386803f1 --- /dev/null +++ b/trl/chat_templates/qwen3_vl.jinja @@ -0,0 +1,120 @@ +{%- if tools %} + {{- '<|im_start|>system\n' }} + {%- if messages[0].role == 'system' %} + {%- if messages[0].content is string %} + {{- messages[0].content }} + {%- else %} + {%- for content in messages[0].content %} + {%- if 'text' in content %} + {{- content.text }} + {%- endif %} + {%- endfor %} + {%- endif %} + {{- '\n\n' }} + {%- endif %} + {{- "# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>" }} + {%- for tool in tools %} + {{- "\n" }} + {{- tool | tojson }} + {%- endfor %} + {{- "\n</tools>\n\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{\"name\": <function-name>, \"arguments\": <args-json-object>}\n</tool_call><|im_end|>\n" }} +{%- else %} + {%- if messages[0].role == 'system' %} + {{- '<|im_start|>system\n' }} + {%- if messages[0].content is string %} + {{- messages[0].content }} + {%- else %} + {%- for content in messages[0].content %} + {%- if 'text' in content %} + {{- content.text }} + {%- endif %} + {%- endfor %} + {%- endif %} + {{- '<|im_end|>\n' }} + {%- endif %} +{%- endif %} +{%- set image_count = namespace(value=0) %} +{%- set video_count = namespace(value=0) %} +{%- for message in messages %} + {%- if message.role == "user" %} + {{- '<|im_start|>' + message.role + '\n' }} + {%- if message.content is string %} + {{- message.content }} + {%- else %} + {%- for content in message.content %} + {%- if content.type == 'image' or 'image' in content or 'image_url' in content %} + {%- set image_count.value = image_count.value + 1 %} + {%- if add_vision_id %}Picture {{ image_count.value }}: {% endif -%} + <|vision_start|><|image_pad|><|vision_end|> + {%- elif 
content.type == 'video' or 'video' in content %} + {%- set video_count.value = video_count.value + 1 %} + {%- if add_vision_id %}Video {{ video_count.value }}: {% endif -%} + <|vision_start|><|video_pad|><|vision_end|> + {%- elif 'text' in content %} + {{- content.text }} + {%- endif %} + {%- endfor %} + {%- endif %} + {{- '<|im_end|>\n' }} + {%- elif message.role == "assistant" %} + {{- '<|im_start|>' + message.role + '\n' }} + {%- if message.content is string %} + {{- message.content }} + {%- else %} + {%- for content_item in message.content %} + {%- if 'text' in content_item %} + {{- content_item.text }} + {%- endif %} + {%- endfor %} + {%- endif %} + {%- if message.tool_calls %} + {%- for tool_call in message.tool_calls %} + {%- if (loop.first and message.content) or (not loop.first) %} + {{- '\n' }} + {%- endif %} + {%- if tool_call.function %} + {%- set tool_call = tool_call.function %} + {%- endif %} + {{- '<tool_call>\n{"name": "' }} + {{- tool_call.name }} + {{- '", "arguments": ' }} + {%- if tool_call.arguments is string %} + {{- tool_call.arguments }} + {%- else %} + {{- tool_call.arguments | tojson }} + {%- endif %} + {{- '}\n</tool_call>' }} + {%- endfor %} + {%- endif %} + {{- '<|im_end|>\n' }} + {%- elif message.role == "tool" %} + {%- if loop.first or (messages[loop.index0 - 1].role != "tool") %} + {{- '<|im_start|>user' }} + {%- endif %} + {{- '\n<tool_response>\n' }} + {%- if message.content is string %} + {{- message.content }} + {%- else %} + {%- for content in message.content %} + {%- if content.type == 'image' or 'image' in content or 'image_url' in content %} + {%- set image_count.value = image_count.value + 1 %} + {%- if add_vision_id %}Picture {{ image_count.value }}: {% endif -%} + <|vision_start|><|image_pad|><|vision_end|> + {%- elif content.type == 'video' or 'video' in content %} + {%- set video_count.value = video_count.value + 1 %} + {%- if add_vision_id %}Video {{ video_count.value }}: {% endif -%} + <|vision_start|><|video_pad|><|vision_end|> + {%- elif 'text' in 
content %} + {{- content.text }} + {%- endif %} + {%- endfor %} + {%- endif %} + {{- '\n</tool_response>' }} + {%- if loop.last or (messages[loop.index0 + 1].role != "tool") %} + {{- '<|im_end|>\n' }} + {%- endif %} + {%- endif %} +{%- endfor %} +{%- if add_generation_prompt %} + {{- '<|im_start|>assistant\n' }} +{%- endif %}