From 4b3aa51d70e6fa8dc65aa1d5ce94509e5e535e17 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Sun, 5 Apr 2026 18:44:21 +0000 Subject: [PATCH 01/17] Narrow prefix-preserving check to the actual requirement --- tests/test_chat_template_utils.py | 17 ++++----- trl/chat_template_utils.py | 57 +++++++++---------------------- 2 files changed, 23 insertions(+), 51 deletions(-) diff --git a/tests/test_chat_template_utils.py b/tests/test_chat_template_utils.py index 6459e3d8fef..2f6cecbb262 100644 --- a/tests/test_chat_template_utils.py +++ b/tests/test_chat_template_utils.py @@ -153,6 +153,8 @@ def test_prefix_preserving_template(self): {{- '<|im_start|>user\n' + message.content + '<|im_end|>\n' }} {%- elif message.role == 'assistant' %} {{- '<|im_start|>assistant\n' + message.content + '<|im_end|>\n' }} + {%- elif message.role == 'tool' %} + {{- '<|im_start|>tool\n' + message.content + '<|im_end|>\n' }} {%- endif %} {%- endfor %} @@ -164,8 +166,9 @@ def test_prefix_preserving_template(self): def test_non_prefix_preserving_template(self): tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen3MoeForSequenceClassification") - # The following template is quite typical of models like Qwen3 and GPT-OSS, where the thinking part is - # only present for last assistant message, which makes it non-prefix-preserving. + # The following template is quite typical of models like Qwen3 and GPT-OSS, where the thinking part (even + # empty) is only present for last assistant message, which makes it non-prefix-preserving: appending a tool + # message changes the earlier output. # docstyle-ignore tokenizer.chat_template = textwrap.dedent(r""" {%- if messages[0].role == 'system' %} @@ -203,6 +206,8 @@ def test_non_prefix_preserving_template(self): {{- '<|im_start|>' + message.role + '\n' + content }} {%- endif %} {{- '<|im_end|>\n' }} + {%- elif message.role == "tool" %} + {{- '<|im_start|>tool\n' + content + '<|im_end|>\n' }} {%- endif %} {%- endfor %} {%- if add_generation_prompt %} @@ -218,14 +223,6 @@ def test_non_prefix_preserving_template(self): "tokenizer_name", [ pytest.param("trl-internal-testing/tiny-Qwen3MoeForSequenceClassification", id="qwen3"), - pytest.param( - "trl-internal-testing/tiny-Qwen3_5ForConditionalGeneration", - id="qwen35", - marks=pytest.mark.skipif( - Version(transformers.__version__) < Version("5.0.0"), - reason="Qwen3.5 tokenizer requires transformers>=5.0.0", - ), - ), ], ) class TestGetTrainingChatTemplate: diff --git a/trl/chat_template_utils.py b/trl/chat_template_utils.py index 90be3fa3b57..eeb3f74e794 100644 --- a/trl/chat_template_utils.py +++ b/trl/chat_template_utils.py @@ -630,6 +630,10 @@ def is_chat_template_prefix_preserving(tokenizer: PreTrainedTokenizer) -> bool: """ Check whether the chat template preserves prefixes when applied. + A prefix-preserving chat template renders earlier messages identically regardless of what messages follow. + This property is required by `_get_tool_suffix_ids`, which extracts tool response formatting tokens by + comparing tokenizations with and without tool messages appended. + Args: tokenizer (`PreTrainedTokenizer`): Tokenizer instance to check. @@ -640,22 +644,18 @@ def is_chat_template_prefix_preserving(tokenizer: PreTrainedTokenizer) -> bool: """ messages1 = [ {"role": "user", "content": "What color is the sky?"}, - ] - messages2 = [ - {"role": "user", "content": "What color is the sky?"}, {"role": "assistant", "content": "It is blue."}, ] - messages3 = [ + messages2 = [ {"role": "user", "content": "What color is the sky?"}, {"role": "assistant", "content": "It is blue."}, - {"role": "user", "content": "And at night?"}, + {"role": "tool", "name": "say_hi", "content": "Hi!"}, ] - text1 = tokenizer.apply_chat_template(messages1, tokenize=False, add_generation_prompt=True) - text2 = tokenizer.apply_chat_template(messages2, tokenize=False) - text3 = tokenizer.apply_chat_template(messages3, tokenize=False) + text1 = tokenizer.apply_chat_template(messages1, tokenize=False) + text2 = tokenizer.apply_chat_template(messages2, tokenize=False, add_generation_prompt=True) - return text2.startswith(text1) and text3.startswith(text2) + return text2.startswith(text1) # Modifications: @@ -749,33 +749,12 @@ def is_chat_template_prefix_preserving(tokenizer: PreTrainedTokenizer) -> bool: {%- endif %}""" -# Modifications: -# - {%- if '' in content %} -# + {%- if '' in content and '' in content %} -# Always check for both tags to avoid edge cases where the model generates only one tag, which would otherwise be parsed incorrectly -# - {{- '<|im_start|>' + message.role + '\n' + content }} -# + {{- '<|im_start|>' + message.role + '\n\n' + reasoning_content + '\n\n\n' + content }} -# Always include thinking block during training. It's important to have a prefix-preserving template. -def _patch_qwen3_5_training_template(template: str) -> str: - return template.replace( - "{%- if '' in content %}", - "{%- if '' in content and '' in content %}", - ).replace( - "{{- '<|im_start|>' + message.role + '\\n' + content }}", - "{{- '<|im_start|>' + message.role + '\\n\\n' + reasoning_content + '\\n\\n\\n' + content }}", - ) - - -qwen3_5_training_chat_template_2b_and_below = _patch_qwen3_5_training_template(qwen3_5_chat_template_2b_and_below) -qwen3_5_training_chat_template_4b_and_above = _patch_qwen3_5_training_template(qwen3_5_chat_template_4b_and_above) - - def get_training_chat_template(tokenizer: PreTrainedTokenizer) -> str | None: r""" Get a prefix-preserving chat template for training, if needed. - If the tokenizer's template isn't prefix-preserving, returns a training-compatible template (currently Qwen3 and - Qwen3.5 supported). Otherwise, returns `None`. + If the tokenizer's template isn't prefix-preserving, returns a training-compatible template (currently Qwen3 + supported). Otherwise, returns `None`. Args: tokenizer (`PreTrainedTokenizer`): @@ -799,21 +778,21 @@ def get_training_chat_template(tokenizer: PreTrainedTokenizer) -> str | None: >>> messages2 = [ ... {"role": "user", "content": "What color is the sky?"}, ... {"role": "assistant", "content": "It is blue."}, - ... {"role": "user", "content": "And at night?"}, + ... {"role": "tool", "name": "say_hi", "content": "Hi!"}, ... ] >>> tokenizer.apply_chat_template(messages1, tokenize=False) '<|im_start|>user\nWhat color is the sky?<|im_end|>\n<|im_start|>assistant\n\n\n\n\nIt is blue.<|im_end|>\n' - >>> tokenizer.apply_chat_template(messages2, tokenize=False) - '<|im_start|>user\nWhat color is the sky?<|im_end|>\n<|im_start|>assistant\nIt is blue.<|im_end|>\n<|im_start|>user\nAnd at night?<|im_end|>\n' + >>> tokenizer.apply_chat_template(messages2, tokenize=False, add_generation_prompt=True) + '<|im_start|>user\nWhat color is the sky?<|im_end|>\n<|im_start|>assistant\nIt is blue.<|im_end|>\n<|im_start|>user\n\nHi!\n<|im_end|>\n<|im_start|>assistant\n' >>> # ^ think tags missing >>> chat_template = get_training_chat_template(tokenizer) >>> tokenizer.apply_chat_template(messages1, tokenize=False, chat_template=chat_template) '<|im_start|>user\nWhat color is the sky?<|im_end|>\n<|im_start|>assistant\n\n\n\n\nIt is blue.<|im_end|>\n' - >>> tokenizer.apply_chat_template(messages2, tokenize=False, chat_template=chat_template) - '<|im_start|>user\nWhat color is the sky?<|im_end|>\n<|im_start|>assistant\n\n\n\n\nIt is blue.<|im_end|>\n<|im_start|>user\nAnd at night?<|im_end|>\n' + >>> tokenizer.apply_chat_template(messages2, tokenize=False, add_generation_prompt=True, chat_template=chat_template) + '<|im_start|>user\nWhat color is the sky?<|im_end|>\n<|im_start|>assistant\n\n\n\n\nIt is blue.<|im_end|>\n<|im_start|>user\n\nHi!\n<|im_end|>\n<|im_start|>assistant\n' ``` """ # First check if patching is needed @@ -822,10 +801,6 @@ def get_training_chat_template(tokenizer: PreTrainedTokenizer) -> str | None: if tokenizer.chat_template == qwen3_chat_template: return qwen3_training_chat_template - if tokenizer.chat_template == qwen3_5_chat_template_2b_and_below: - return qwen3_5_training_chat_template_2b_and_below - if tokenizer.chat_template == qwen3_5_chat_template_4b_and_above: - return qwen3_5_training_chat_template_4b_and_above else: raise ValueError( "The tokenizer's chat template is not prefix-preserving and patching is not supported for this template. " From 730070bcf14f58ae60c9646ec379441d4d112c01 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Sun, 5 Apr 2026 19:00:02 +0000 Subject: [PATCH 02/17] Update chat template examples to use multiplication function calls --- trl/chat_template_utils.py | 30 ++++++++++++++---------------- 1 file changed, 14 insertions(+), 16 deletions(-) diff --git a/trl/chat_template_utils.py b/trl/chat_template_utils.py index eeb3f74e794..21bbe598685 100644 --- a/trl/chat_template_utils.py +++ b/trl/chat_template_utils.py @@ -643,13 +643,13 @@ def is_chat_template_prefix_preserving(tokenizer: PreTrainedTokenizer) -> bool: `True` if the chat template preserves prefixes, `False` otherwise. """ messages1 = [ - {"role": "user", "content": "What color is the sky?"}, - {"role": "assistant", "content": "It is blue."}, + {"role": "user", "content": "What is 2 * 3?"}, + {"role": "assistant", "content": "", "tool_calls": [{"type": "function", "function": {"name": "multiply", "arguments": {"a": 2, "b": 3}}}]}, ] messages2 = [ - {"role": "user", "content": "What color is the sky?"}, - {"role": "assistant", "content": "It is blue."}, - {"role": "tool", "name": "say_hi", "content": "Hi!"}, + {"role": "user", "content": "What is 2 * 3?"}, + {"role": "assistant", "content": "", "tool_calls": [{"type": "function", "function": {"name": "multiply", "arguments": {"a": 2, "b": 3}}}]}, + {"role": "tool", "name": "multiply", "content": "6"}, ] text1 = tokenizer.apply_chat_template(messages1, tokenize=False) @@ -772,27 +772,25 @@ def get_training_chat_template(tokenizer: PreTrainedTokenizer) -> str | None: >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B") >>> messages1 = [ - ... {"role": "user", "content": "What color is the sky?"}, - ... {"role": "assistant", "content": "It is blue."}, + ... {"role": "user", "content": "What is 2 * 3?"}, + ... {"role": "assistant", "content": "", "tool_calls": [{"type": "function", "function": {"name": "multiply", "arguments": {"a": 2, "b": 3}}}]}, ... ] - >>> messages2 = [ - ... {"role": "user", "content": "What color is the sky?"}, - ... {"role": "assistant", "content": "It is blue."}, - ... {"role": "tool", "name": "say_hi", "content": "Hi!"}, + >>> messages2 = messages1 + [ + ... {"role": "tool", "name": "multiply", "content": "6"}, ... ] >>> tokenizer.apply_chat_template(messages1, tokenize=False) - '<|im_start|>user\nWhat color is the sky?<|im_end|>\n<|im_start|>assistant\n\n\n\n\nIt is blue.<|im_end|>\n' + '<|im_start|>user\nWhat is 2 * 3?<|im_end|>\n<|im_start|>assistant\n\n\n\n\n\n{"name": "multiply", "arguments": {"a": 2, "b": 3}}\n<|im_end|>\n' >>> tokenizer.apply_chat_template(messages2, tokenize=False, add_generation_prompt=True) - '<|im_start|>user\nWhat color is the sky?<|im_end|>\n<|im_start|>assistant\nIt is blue.<|im_end|>\n<|im_start|>user\n\nHi!\n<|im_end|>\n<|im_start|>assistant\n' + '<|im_start|>user\nWhat is 2 * 3?<|im_end|>\n<|im_start|>assistant\n\n{"name": "multiply", "arguments": {"a": 2, "b": 3}}\n<|im_end|>\n<|im_start|>user\n\n6\n<|im_end|>\n<|im_start|>assistant\n' - >>> # ^ think tags missing + >>> # ^ think tags missing >>> chat_template = get_training_chat_template(tokenizer) >>> tokenizer.apply_chat_template(messages1, tokenize=False, chat_template=chat_template) - '<|im_start|>user\nWhat color is the sky?<|im_end|>\n<|im_start|>assistant\n\n\n\n\nIt is blue.<|im_end|>\n' + '<|im_start|>user\nWhat is 2 * 3?<|im_end|>\n<|im_start|>assistant\n\n\n\n\n\n{"name": "multiply", "arguments": {"a": 2, "b": 3}}\n<|im_end|>\n' >>> tokenizer.apply_chat_template(messages2, tokenize=False, add_generation_prompt=True, chat_template=chat_template) - '<|im_start|>user\nWhat color is the sky?<|im_end|>\n<|im_start|>assistant\n\n\n\n\nIt is blue.<|im_end|>\n<|im_start|>user\n\nHi!\n<|im_end|>\n<|im_start|>assistant\n' + '<|im_start|>user\nWhat is 2 * 3?<|im_end|>\n<|im_start|>assistant\n\n\n\n\n\n{"name": "multiply", "arguments": {"a": 2, "b": 3}}\n<|im_end|>\n<|im_start|>user\n\n6\n<|im_end|>\n<|im_start|>assistant\n' ``` """ # First check if patching is needed From 4622d77234bf46d010b782fcebab1b2ed2d2e32a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Sun, 5 Apr 2026 19:01:42 +0000 Subject: [PATCH 03/17] style --- trl/chat_template_utils.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/trl/chat_template_utils.py b/trl/chat_template_utils.py index 21bbe598685..6b9d3c24e2b 100644 --- a/trl/chat_template_utils.py +++ b/trl/chat_template_utils.py @@ -644,11 +644,19 @@ def is_chat_template_prefix_preserving(tokenizer: PreTrainedTokenizer) -> bool: """ messages1 = [ {"role": "user", "content": "What is 2 * 3?"}, - {"role": "assistant", "content": "", "tool_calls": [{"type": "function", "function": {"name": "multiply", "arguments": {"a": 2, "b": 3}}}]}, + { + "role": "assistant", + "content": "", + "tool_calls": [{"type": "function", "function": {"name": "multiply", "arguments": {"a": 2, "b": 3}}}], + }, ] messages2 = [ {"role": "user", "content": "What is 2 * 3?"}, - {"role": "assistant", "content": "", "tool_calls": [{"type": "function", "function": {"name": "multiply", "arguments": {"a": 2, "b": 3}}}]}, + { + "role": "assistant", + "content": "", + "tool_calls": [{"type": "function", "function": {"name": "multiply", "arguments": {"a": 2, "b": 3}}}], + }, {"role": "tool", "name": "multiply", "content": "6"}, ] From 08d4c518ec29b2c9d3e0f2fff5eae90ca9aebdf9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Sun, 5 Apr 2026 19:33:46 +0000 Subject: [PATCH 04/17] Move chat templates from inline strings to `.jinja` files --- trl/chat_template_utils.py | 503 +----------------- trl/chat_templates/README.md | 50 ++ trl/chat_templates/qwen3.jinja | 89 ++++ trl/chat_templates/qwen3_5_2b_and_below.jinja | 154 ++++++ trl/chat_templates/qwen3_5_4b_and_above.jinja | 154 ++++++ trl/chat_templates/qwen3_training.jinja | 88 +++ 6 files changed, 546 insertions(+), 492 deletions(-) create mode 100644 trl/chat_templates/README.md create mode 100644 trl/chat_templates/qwen3.jinja create mode 100644 trl/chat_templates/qwen3_5_2b_and_below.jinja create mode 100644 trl/chat_templates/qwen3_5_4b_and_above.jinja create mode 100644 trl/chat_templates/qwen3_training.jinja diff --git a/trl/chat_template_utils.py b/trl/chat_template_utils.py index 6b9d3c24e2b..3e86874816a 100644 --- a/trl/chat_template_utils.py +++ b/trl/chat_template_utils.py @@ -12,8 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. +from pathlib import Path + from transformers import AddedToken, AutoTokenizer, PreTrainedModel, PreTrainedTokenizer +_CHAT_TEMPLATES_DIR = Path(__file__).parent / "chat_templates" + def clone_chat_template( model: PreTrainedModel, @@ -178,409 +182,12 @@ def clone_chat_template( }, } -# docstyle-ignore -qwen3_chat_template = r"""{%- if tools %} - {{- '<|im_start|>system\n' }} - {%- if messages[0].role == 'system' %} - {{- messages[0].content + '\n\n' }} - {%- endif %} - {{- "# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within XML tags:\n" }} - {%- for tool in tools %} - {{- "\n" }} - {{- tool | tojson }} - {%- endfor %} - {{- "\n\n\nFor each function call, return a json object with function name and arguments within XML tags:\n\n{\"name\": , \"arguments\": }\n<|im_end|>\n" }} -{%- else %} - {%- if messages[0].role == 'system' %} - {{- '<|im_start|>system\n' + messages[0].content + '<|im_end|>\n' }} - {%- endif %} -{%- endif %} -{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %} -{%- for message in messages[::-1] %} - {%- set index = (messages|length - 1) - loop.index0 %} - {%- if ns.multi_step_tool and message.role == "user" and message.content is string and not(message.content.startswith('') and message.content.endswith('')) %} - {%- set ns.multi_step_tool = false %} - {%- set ns.last_query_index = index %} - {%- endif %} -{%- endfor %} -{%- for message in messages %} - {%- if message.content is string %} - {%- set content = message.content %} - {%- else %} - {%- set content = '' %} - {%- endif %} - {%- if (message.role == "user") or (message.role == "system" and not loop.first) %} - {{- '<|im_start|>' + message.role + '\n' + content + '<|im_end|>' + '\n' }} - {%- elif message.role == "assistant" %} - {%- set reasoning_content = '' %} - {%- if message.reasoning_content is string %} - {%- set reasoning_content = message.reasoning_content %} - {%- else %} - {%- if '' in content %} - {%- set reasoning_content = content.split('')[0].rstrip('\n').split('')[-1].lstrip('\n') %} - {%- set content = content.split('')[-1].lstrip('\n') %} - {%- endif %} - {%- endif %} - {%- if loop.index0 > ns.last_query_index %} - {%- if loop.last or (not loop.last and reasoning_content) %} - {{- '<|im_start|>' + message.role + '\n\n' + reasoning_content.strip('\n') + '\n\n\n' + content.lstrip('\n') }} - {%- else %} - {{- '<|im_start|>' + message.role + '\n' + content }} - {%- endif %} - {%- else %} - {{- '<|im_start|>' + message.role + '\n' + content }} - {%- endif %} - {%- if message.tool_calls %} - {%- for tool_call in message.tool_calls %} - {%- if (loop.first and content) or (not loop.first) %} - {{- '\n' }} - {%- endif %} - {%- if tool_call.function %} - {%- set tool_call = tool_call.function %} - {%- endif %} - {{- '\n{"name": "' }} - {{- tool_call.name }} - {{- '", "arguments": ' }} - {%- if tool_call.arguments is string %} - {{- tool_call.arguments }} - {%- else %} - {{- tool_call.arguments | tojson }} - {%- endif %} - {{- '}\n' }} - {%- endfor %} - {%- endif %} - {{- '<|im_end|>\n' }} - {%- elif message.role == "tool" %} - {%- if loop.first or (messages[loop.index0 - 1].role != "tool") %} - {{- '<|im_start|>user' }} - {%- endif %} - {{- '\n\n' }} - {{- content }} - {{- '\n' }} - {%- if loop.last or (messages[loop.index0 + 1].role != "tool") %} - {{- '<|im_end|>\n' }} - {%- endif %} - {%- endif %} -{%- endfor %} -{%- if add_generation_prompt %} - {{- '<|im_start|>assistant\n' }} - {%- if enable_thinking is defined and enable_thinking is false %} - {{- '\n\n\n\n' }} - {%- endif %} -{%- endif %}""" - -# docstyle-ignore -qwen3_5_chat_template_2b_and_below = r"""{%- set image_count = namespace(value=0) %} -{%- set video_count = namespace(value=0) %} -{%- macro render_content(content, do_vision_count, is_system_content=false) %} - {%- if content is string %} - {{- content }} - {%- elif content is iterable and content is not mapping %} - {%- for item in content %} - {%- if 'image' in item or 'image_url' in item or item.type == 'image' %} - {%- if is_system_content %} - {{- raise_exception('System message cannot contain images.') }} - {%- endif %} - {%- if do_vision_count %} - {%- set image_count.value = image_count.value + 1 %} - {%- endif %} - {%- if add_vision_id %} - {{- 'Picture ' ~ image_count.value ~ ': ' }} - {%- endif %} - {{- '<|vision_start|><|image_pad|><|vision_end|>' }} - {%- elif 'video' in item or item.type == 'video' %} - {%- if is_system_content %} - {{- raise_exception('System message cannot contain videos.') }} - {%- endif %} - {%- if do_vision_count %} - {%- set video_count.value = video_count.value + 1 %} - {%- endif %} - {%- if add_vision_id %} - {{- 'Video ' ~ video_count.value ~ ': ' }} - {%- endif %} - {{- '<|vision_start|><|video_pad|><|vision_end|>' }} - {%- elif 'text' in item %} - {{- item.text }} - {%- else %} - {{- raise_exception('Unexpected item type in content.') }} - {%- endif %} - {%- endfor %} - {%- elif content is none or content is undefined %} - {{- '' }} - {%- else %} - {{- raise_exception('Unexpected content type.') }} - {%- endif %} -{%- endmacro %} -{%- if not messages %} - {{- raise_exception('No messages provided.') }} -{%- endif %} -{%- if tools and tools is iterable and tools is not mapping %} - {{- '<|im_start|>system\n' }} - {{- "# Tools\n\nYou have access to the following functions:\n\n" }} - {%- for tool in tools %} - {{- "\n" }} - {{- tool | tojson }} - {%- endfor %} - {{- "\n" }} - {{- '\n\nIf you choose to call a function ONLY reply in the following format with NO suffix:\n\n\n\n\nvalue_1\n\n\nThis is the value for the second parameter\nthat can span\nmultiple lines\n\n\n\n\n\nReminder:\n- Function calls MUST follow the specified format: an inner block must be nested within XML tags\n- Required parameters MUST be specified\n- You may provide optional reasoning for your function call in natural language BEFORE the function call, but NOT after\n- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls\n' }} - {%- if messages[0].role == 'system' %} - {%- set content = render_content(messages[0].content, false, true)|trim %} - {%- if content %} - {{- '\n\n' + content }} - {%- endif %} - {%- endif %} - {{- '<|im_end|>\n' }} -{%- else %} - {%- if messages[0].role == 'system' %} - {%- set content = render_content(messages[0].content, false, true)|trim %} - {{- '<|im_start|>system\n' + content + '<|im_end|>\n' }} - {%- endif %} -{%- endif %} -{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %} -{%- for message in messages[::-1] %} - {%- set index = (messages|length - 1) - loop.index0 %} - {%- if ns.multi_step_tool and message.role == "user" %} - {%- set content = render_content(message.content, false)|trim %} - {%- if not(content.startswith('') and content.endswith('')) %} - {%- set ns.multi_step_tool = false %} - {%- set ns.last_query_index = index %} - {%- endif %} - {%- endif %} -{%- endfor %} -{%- if ns.multi_step_tool %} - {{- raise_exception('No user query found in messages.') }} -{%- endif %} -{%- for message in messages %} - {%- set content = render_content(message.content, true)|trim %} - {%- if message.role == "system" %} - {%- if not loop.first %} - {{- raise_exception('System message must be at the beginning.') }} - {%- endif %} - {%- elif message.role == "user" %} - {{- '<|im_start|>' + message.role + '\n' + content + '<|im_end|>' + '\n' }} - {%- elif message.role == "assistant" %} - {%- set reasoning_content = '' %} - {%- if message.reasoning_content is string %} - {%- set reasoning_content = message.reasoning_content %} - {%- else %} - {%- if '' in content %} - {%- set reasoning_content = content.split('')[0].rstrip('\n').split('')[-1].lstrip('\n') %} - {%- set content = content.split('')[-1].lstrip('\n') %} - {%- endif %} - {%- endif %} - {%- set reasoning_content = reasoning_content|trim %} - {%- if loop.index0 > ns.last_query_index %} - {{- '<|im_start|>' + message.role + '\n\n' + reasoning_content + '\n\n\n' + content }} - {%- else %} - {{- '<|im_start|>' + message.role + '\n' + content }} - {%- endif %} - {%- if message.tool_calls and message.tool_calls is iterable and message.tool_calls is not mapping %} - {%- for tool_call in message.tool_calls %} - {%- if tool_call.function is defined %} - {%- set tool_call = tool_call.function %} - {%- endif %} - {%- if loop.first %} - {%- if content|trim %} - {{- '\n\n\n\n' }} - {%- else %} - {{- '\n\n' }} - {%- endif %} - {%- else %} - {{- '\n\n\n' }} - {%- endif %} - {%- if tool_call.arguments is defined %} - {%- for args_name, args_value in tool_call.arguments|items %} - {{- '\n' }} - {%- set args_value = args_value | tojson | safe if args_value is mapping or (args_value is sequence and args_value is not string) else args_value | string %} - {{- args_value }} - {{- '\n\n' }} - {%- endfor %} - {%- endif %} - {{- '\n' }} - {%- endfor %} - {%- endif %} - {{- '<|im_end|>\n' }} - {%- elif message.role == "tool" %} - {%- if loop.previtem and loop.previtem.role != "tool" %} - {{- '<|im_start|>user' }} - {%- endif %} - {{- '\n\n' }} - {{- content }} - {{- '\n' }} - {%- if not loop.last and loop.nextitem.role != "tool" %} - {{- '<|im_end|>\n' }} - {%- elif loop.last %} - {{- '<|im_end|>\n' }} - {%- endif %} - {%- else %} - {{- raise_exception('Unexpected message role.') }} - {%- endif %} -{%- endfor %} -{%- if add_generation_prompt %} - {{- '<|im_start|>assistant\n' }} - {%- if enable_thinking is defined and enable_thinking is true %} - {{- '\n' }} - {%- else %} - {{- '\n\n\n\n' }} - {%- endif %} -{%- endif %}""" - - -# docstyle-ignore -qwen3_5_chat_template_4b_and_above = r"""{%- set image_count = namespace(value=0) %} -{%- set video_count = namespace(value=0) %} -{%- macro render_content(content, do_vision_count, is_system_content=false) %} - {%- if content is string %} - {{- content }} - {%- elif content is iterable and content is not mapping %} - {%- for item in content %} - {%- if 'image' in item or 'image_url' in item or item.type == 'image' %} - {%- if is_system_content %} - {{- raise_exception('System message cannot contain images.') }} - {%- endif %} - {%- if do_vision_count %} - {%- set image_count.value = image_count.value + 1 %} - {%- endif %} - {%- if add_vision_id %} - {{- 'Picture ' ~ image_count.value ~ ': ' }} - {%- endif %} - {{- '<|vision_start|><|image_pad|><|vision_end|>' }} - {%- elif 'video' in item or item.type == 'video' %} - {%- if is_system_content %} - {{- raise_exception('System message cannot contain videos.') }} - {%- endif %} - {%- if do_vision_count %} - {%- set video_count.value = video_count.value + 1 %} - {%- endif %} - {%- if add_vision_id %} - {{- 'Video ' ~ video_count.value ~ ': ' }} - {%- endif %} - {{- '<|vision_start|><|video_pad|><|vision_end|>' }} - {%- elif 'text' in item %} - {{- item.text }} - {%- else %} - {{- raise_exception('Unexpected item type in content.') }} - {%- endif %} - {%- endfor %} - {%- elif content is none or content is undefined %} - {{- '' }} - {%- else %} - {{- raise_exception('Unexpected content type.') }} - {%- endif %} -{%- endmacro %} -{%- if not messages %} - {{- raise_exception('No messages provided.') }} -{%- endif %} -{%- if tools and tools is iterable and tools is not mapping %} - {{- '<|im_start|>system\n' }} - {{- "# Tools\n\nYou have access to the following functions:\n\n" }} - {%- for tool in tools %} - {{- "\n" }} - {{- tool | tojson }} - {%- endfor %} - {{- "\n" }} - {{- '\n\nIf you choose to call a function ONLY reply in the following format with NO suffix:\n\n\n\n\nvalue_1\n\n\nThis is the value for the second parameter\nthat can span\nmultiple lines\n\n\n\n\n\nReminder:\n- Function calls MUST follow the specified format: an inner block must be nested within XML tags\n- Required parameters MUST be specified\n- You may provide optional reasoning for your function call in natural language BEFORE the function call, but NOT after\n- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls\n' }} - {%- if messages[0].role == 'system' %} - {%- set content = render_content(messages[0].content, false, true)|trim %} - {%- if content %} - {{- '\n\n' + content }} - {%- endif %} - {%- endif %} - {{- '<|im_end|>\n' }} -{%- else %} - {%- if messages[0].role == 'system' %} - {%- set content = render_content(messages[0].content, false, true)|trim %} - {{- '<|im_start|>system\n' + content + '<|im_end|>\n' }} - {%- endif %} -{%- endif %} -{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %} -{%- for message in messages[::-1] %} - {%- set index = (messages|length - 1) - loop.index0 %} - {%- if ns.multi_step_tool and message.role == "user" %} - {%- set content = render_content(message.content, false)|trim %} - {%- if not(content.startswith('') and content.endswith('')) %} - {%- set ns.multi_step_tool = false %} - {%- set ns.last_query_index = index %} - {%- endif %} - {%- endif %} -{%- endfor %} -{%- if ns.multi_step_tool %} - {{- raise_exception('No user query found in messages.') }} -{%- endif %} -{%- for message in messages %} - {%- set content = render_content(message.content, true)|trim %} - {%- if message.role == "system" %} - {%- if not loop.first %} - {{- raise_exception('System message must be at the beginning.') }} - {%- endif %} - {%- elif message.role == "user" %} - {{- '<|im_start|>' + message.role + '\n' + content + '<|im_end|>' + '\n' }} - {%- elif message.role == "assistant" %} - {%- set reasoning_content = '' %} - {%- if message.reasoning_content is string %} - {%- set reasoning_content = message.reasoning_content %} - {%- else %} - {%- if '' in content %} - {%- set reasoning_content = content.split('')[0].rstrip('\n').split('')[-1].lstrip('\n') %} - {%- set content = content.split('')[-1].lstrip('\n') %} - {%- endif %} - {%- endif %} - {%- set reasoning_content = reasoning_content|trim %} - {%- if loop.index0 > ns.last_query_index %} - {{- '<|im_start|>' + message.role + '\n\n' + reasoning_content + '\n\n\n' + content }} - {%- else %} - {{- '<|im_start|>' + message.role + '\n' + content }} - {%- endif %} - {%- if message.tool_calls and message.tool_calls is iterable and message.tool_calls is not mapping %} - {%- for tool_call in message.tool_calls %} - {%- if tool_call.function is defined %} - {%- set tool_call = tool_call.function %} - {%- endif %} - {%- if loop.first %} - {%- if content|trim %} - {{- '\n\n\n\n' }} - {%- else %} - {{- '\n\n' }} - {%- endif %} - {%- else %} - {{- '\n\n\n' }} - {%- endif %} - {%- if tool_call.arguments is defined %} - {%- for args_name, args_value in tool_call.arguments|items %} - {{- '\n' }} - {%- set args_value = args_value | tojson | safe if args_value is mapping or (args_value is sequence and args_value is not string) else args_value | string %} - {{- args_value }} - {{- '\n\n' }} - {%- endfor %} - {%- endif %} - {{- '\n' }} - {%- endfor %} - {%- endif %} - {{- '<|im_end|>\n' }} - {%- elif message.role == "tool" %} - {%- if loop.previtem and loop.previtem.role != "tool" %} - {{- '<|im_start|>user' }} - {%- endif %} - {{- '\n\n' }} - {{- content }} - {{- '\n' }} - {%- if not loop.last and loop.nextitem.role != "tool" %} - {{- '<|im_end|>\n' }} - {%- elif loop.last %} - {{- '<|im_end|>\n' }} - {%- endif %} - {%- else %} - {{- raise_exception('Unexpected message role.') }} - {%- endif %} -{%- endfor %} -{%- if add_generation_prompt %} - {{- '<|im_start|>assistant\n' }} - {%- if enable_thinking is defined and enable_thinking is false %} - {{- '\n\n\n\n' }} - {%- else %} - {{- '\n' }} - {%- endif %} -{%- endif %}""" +qwen3_chat_template = (_CHAT_TEMPLATES_DIR / "qwen3.jinja").read_text() + +qwen3_5_chat_template_2b_and_below = (_CHAT_TEMPLATES_DIR / "qwen3_5_2b_and_below.jinja").read_text() + + +qwen3_5_chat_template_4b_and_above = (_CHAT_TEMPLATES_DIR / "qwen3_5_4b_and_above.jinja").read_text() def add_response_schema(tokenizer: PreTrainedTokenizer) -> PreTrainedTokenizer: @@ -666,95 +273,7 @@ def is_chat_template_prefix_preserving(tokenizer: PreTrainedTokenizer) -> bool: return text2.startswith(text1) -# Modifications: -# - {%- if '' in content %} -# + {%- if '' in content and '' in content %} -# Always check for both tags to avoid edge cases where the model generates only one tag, which would otherwise be parsed incorrectly -# - {%- if loop.index0 > ns.last_query_index %} ... {%- endif %} -# + {{- '<|im_start|>' + message.role + '\n\n' + reasoning_content.strip('\n') + '\n\n\n' + content.lstrip('\n') }} -# Always include thinking block during training. It's important to have a prefix-preserving template. -# docstyle-ignore -qwen3_training_chat_template = r"""{%- if tools %} - {{- '<|im_start|>system\n' }} - {%- if messages[0].role == 'system' %} - {{- messages[0].content + '\n\n' }} - {%- endif %} - {{- "# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within XML tags:\n" }} - {%- for tool in tools %} - {{- "\n" }} - {{- tool | tojson }} - {%- endfor %} - {{- "\n\n\nFor each function call, return a json object with function name and arguments within XML tags:\n\n{\"name\": , \"arguments\": }\n<|im_end|>\n" }} -{%- else %} - {%- if messages[0].role == 'system' %} - {{- '<|im_start|>system\n' + messages[0].content + '<|im_end|>\n' }} - {%- endif %} -{%- endif %} -{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %} -{%- for message in messages[::-1] %} - {%- set index = (messages|length - 1) - loop.index0 %} - {%- if ns.multi_step_tool and message.role == "user" and message.content is string and not(message.content.startswith('') and message.content.endswith('')) %} - {%- set ns.multi_step_tool = false %} - {%- set ns.last_query_index = index %} - {%- endif %} -{%- endfor %} -{%- for message in messages %} - {%- if message.content is string %} - {%- set content = message.content %} - {%- else %} - {%- set content = '' %} - {%- endif %} - {%- if (message.role == "user") or (message.role == "system" and not loop.first) %} - {{- '<|im_start|>' + message.role + '\n' + content + '<|im_end|>' + '\n' }} - {%- elif message.role == "assistant" %} - {%- set reasoning_content = '' %} - {%- if message.reasoning_content is string %} - {%- set reasoning_content = message.reasoning_content %} - {%- else %} - {%- if '' in content and '' in content %} - {%- set reasoning_content = content.split('')[0].rstrip('\n').split('')[-1].lstrip('\n') %} - {%- set content = content.split('')[-1].lstrip('\n') %} - {%- endif %} - {%- endif %} - {{- '<|im_start|>' + message.role + '\n\n' + reasoning_content.strip('\n') + '\n\n\n' + content.lstrip('\n') }} - {%- if message.tool_calls %} - {%- for tool_call in message.tool_calls %} - {%- if (loop.first and content) or (not loop.first) %} - {{- '\n' }} - {%- endif %} - {%- if tool_call.function %} - {%- set tool_call = tool_call.function %} - {%- endif %} - {{- '\n{"name": "' }} - {{- tool_call.name }} - {{- '", "arguments": ' }} - {%- if tool_call.arguments is string %} - {{- tool_call.arguments }} - {%- else %} - {{- tool_call.arguments | tojson }} - {%- endif %} - {{- '}\n' }} - {%- endfor %} - {%- endif %} - {{- '<|im_end|>\n' }} - {%- elif message.role == "tool" %} - {%- if loop.first or (messages[loop.index0 - 1].role != "tool") %} - {{- '<|im_start|>user' }} - {%- endif %} - {{- '\n\n' }} - {{- content }} - {{- '\n' }} - {%- if loop.last or (messages[loop.index0 + 1].role != "tool") %} - {{- '<|im_end|>\n' }} - {%- endif %} - {%- endif %} -{%- endfor %} -{%- if add_generation_prompt %} - {{- '<|im_start|>assistant\n' }} - {%- if enable_thinking is defined and enable_thinking is false %} - {{- '\n\n\n\n' }} - {%- endif %} -{%- endif %}""" +qwen3_training_chat_template = (_CHAT_TEMPLATES_DIR / "qwen3_training.jinja").read_text() def get_training_chat_template(tokenizer: PreTrainedTokenizer) -> str | None: diff --git a/trl/chat_templates/README.md b/trl/chat_templates/README.md new file mode 100644 index 00000000000..5eb01cc4aca --- /dev/null +++ b/trl/chat_templates/README.md @@ -0,0 +1,50 @@ +# Chat Templates + +Jinja2 chat templates stored here serve two purposes: + +1. **Identity comparison**: detecting which model is being used (by comparing `tokenizer.chat_template` against known templates) to add the appropriate response schema (`add_response_schema`) or swap in a training template (`get_training_chat_template`). +2. **Training patches**: modified templates that fix prefix-preservation issues for the GRPO tool call loop. + +**Why prefix-preserving?** The GRPO tool call loop extracts tool response formatting tokens by comparing tokenizations with and without tool messages appended (`_get_tool_suffix_ids`). This requires the chat template to be *prefix-preserving*: appending messages must not change how earlier messages are rendered. + +## Original templates + +Used for identity comparison only. + +### `qwen3.jinja` + +Original Qwen3 chat template. + +### `qwen3_5_2b_and_below.jinja` / `qwen3_5_4b_and_above.jinja` + +Original Qwen3.5 chat templates. + +## Training templates + +Patched templates that fix prefix-preservation issues. Swapped in at init when tools are enabled. + +### `qwen3_training.jinja` + +Patched Qwen3 template that always includes thinking blocks, making it prefix-preserving. Diff vs `qwen3.jinja`: + +Require both `` and `` to be present before parsing, to avoid incorrect splitting when the model generates only one tag: + +```diff +- {%- if '' in content %} ++ {%- if '' in content and '' in content %} +``` + +Always include the thinking block regardless of message position. The original conditionally omits it based on `loop.last`, which changes the assistant rendering when a tool message is appended — breaking prefix-preservation: + +```diff +- {%- if loop.index0 > ns.last_query_index %} +- {%- if loop.last or (not loop.last and reasoning_content) %} +- {{- '<|im_start|>' + message.role + '\n\n' + reasoning_content.strip('\n') + '\n\n\n' + content.lstrip('\n') }} +- {%- else %} +- {{- '<|im_start|>' + message.role + '\n' + content }} +- {%- endif %} +- {%- else %} +- {{- '<|im_start|>' + message.role + '\n' + content }} +- {%- endif %} ++ {{- '<|im_start|>' + message.role + '\n\n' + reasoning_content.strip('\n') + '\n\n\n' + content.lstrip('\n') }} +``` diff --git a/trl/chat_templates/qwen3.jinja b/trl/chat_templates/qwen3.jinja new file mode 100644 index 00000000000..01be9b307da --- /dev/null +++ b/trl/chat_templates/qwen3.jinja @@ -0,0 +1,89 @@ +{%- if tools %} + {{- '<|im_start|>system\n' }} + {%- if messages[0].role == 'system' %} + {{- messages[0].content + '\n\n' }} + {%- endif %} + {{- "# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within XML tags:\n" }} + {%- for tool in tools %} + {{- "\n" }} + {{- tool | tojson }} + {%- endfor %} + {{- "\n\n\nFor each function call, return a json object with function name and arguments within XML tags:\n\n{\"name\": , \"arguments\": }\n<|im_end|>\n" }} +{%- else %} + {%- if messages[0].role == 'system' %} + {{- '<|im_start|>system\n' + messages[0].content + '<|im_end|>\n' }} + {%- endif %} +{%- endif %} +{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %} +{%- for message in messages[::-1] %} + {%- set index = (messages|length - 1) - loop.index0 %} + {%- if ns.multi_step_tool and message.role == "user" and message.content is string and not(message.content.startswith('') and message.content.endswith('')) %} + {%- set ns.multi_step_tool = false %} + {%- set ns.last_query_index = index %} + {%- endif %} +{%- endfor %} +{%- for message in messages %} + {%- if message.content is string %} + {%- set content = message.content %} + {%- else %} + {%- set content = '' %} + {%- endif %} + {%- if (message.role == "user") or (message.role == "system" and not loop.first) %} + {{- '<|im_start|>' + message.role + '\n' + content + '<|im_end|>' + '\n' }} + {%- elif message.role == "assistant" %} + {%- set reasoning_content = '' %} + {%- if message.reasoning_content is string %} + {%- set reasoning_content = message.reasoning_content %} + {%- else %} + {%- if '' in content %} + {%- set reasoning_content = content.split('')[0].rstrip('\n').split('')[-1].lstrip('\n') %} + {%- set content = content.split('')[-1].lstrip('\n') %} + {%- endif %} + {%- endif %} + {%- if loop.index0 > ns.last_query_index %} + {%- if loop.last or (not loop.last and reasoning_content) %} + {{- '<|im_start|>' + message.role + '\n\n' + reasoning_content.strip('\n') + '\n\n\n' + content.lstrip('\n') }} + {%- else %} + {{- '<|im_start|>' + message.role + '\n' + content }} + {%- endif %} + {%- else %} + {{- '<|im_start|>' + message.role + '\n' + content }} + {%- endif %} + {%- if message.tool_calls %} + {%- for tool_call in message.tool_calls %} + {%- if (loop.first and content) or (not loop.first) %} + {{- '\n' }} + {%- endif %} + {%- if tool_call.function %} + {%- set tool_call = tool_call.function %} + {%- endif %} + {{- '\n{"name": "' }} + {{- tool_call.name }} + {{- '", "arguments": ' }} + {%- if tool_call.arguments is string %} + {{- tool_call.arguments }} + {%- else %} + {{- tool_call.arguments | tojson }} + {%- endif %} + {{- '}\n' }} + {%- endfor %} + {%- endif %} + {{- '<|im_end|>\n' }} + {%- elif message.role == "tool" %} + {%- if loop.first or (messages[loop.index0 - 1].role != "tool") %} + {{- '<|im_start|>user' }} + {%- endif %} + {{- '\n\n' }} + {{- content }} + {{- '\n' }} + {%- if loop.last or (messages[loop.index0 + 1].role != "tool") %} + {{- '<|im_end|>\n' }} + {%- endif %} + {%- endif %} +{%- endfor %} +{%- if add_generation_prompt %} + {{- '<|im_start|>assistant\n' }} + {%- if enable_thinking is defined and enable_thinking is false %} + {{- '\n\n\n\n' }} + {%- endif %} +{%- endif %} \ No newline at end of file diff --git a/trl/chat_templates/qwen3_5_2b_and_below.jinja b/trl/chat_templates/qwen3_5_2b_and_below.jinja new file mode 100644 index 00000000000..0ef09f214ea --- /dev/null +++ b/trl/chat_templates/qwen3_5_2b_and_below.jinja @@ -0,0 +1,154 @@ +{%- set image_count = namespace(value=0) %} +{%- set video_count = namespace(value=0) %} +{%- macro render_content(content, do_vision_count, is_system_content=false) %} + {%- if content is string %} + {{- content }} + {%- elif content is iterable and content is not mapping %} + {%- for item in content %} + {%- if 'image' in item or 'image_url' in item or item.type == 'image' %} + {%- if is_system_content %} + {{- raise_exception('System message cannot contain images.') }} + {%- endif %} + {%- if do_vision_count %} + {%- set image_count.value = image_count.value + 1 %} + {%- endif %} + {%- if add_vision_id %} + {{- 'Picture ' ~ image_count.value ~ ': ' }} + {%- endif %} + {{- '<|vision_start|><|image_pad|><|vision_end|>' }} + {%- elif 'video' in item or item.type == 'video' %} + {%- if is_system_content %} + {{- raise_exception('System message cannot contain videos.') }} + {%- endif %} + {%- if do_vision_count %} + {%- set video_count.value = video_count.value + 1 %} + {%- endif %} + {%- if add_vision_id %} + {{- 'Video ' ~ video_count.value ~ ': ' }} + {%- endif %} + {{- '<|vision_start|><|video_pad|><|vision_end|>' }} + {%- elif 'text' in item %} + {{- item.text }} + {%- else %} + {{- raise_exception('Unexpected item type in content.') }} + {%- endif %} + {%- endfor %} + {%- elif content is none or content is undefined %} + {{- '' }} + {%- else %} + {{- raise_exception('Unexpected content type.') }} + {%- endif %} +{%- endmacro %} +{%- if not messages %} + {{- raise_exception('No messages provided.') }} +{%- endif %} +{%- if tools and tools is iterable and tools is not mapping %} + {{- '<|im_start|>system\n' }} + {{- "# Tools\n\nYou have access to the following functions:\n\n" }} + {%- for tool in tools %} + {{- "\n" }} + {{- tool | tojson }} + {%- endfor %} + {{- "\n" }} + {{- '\n\nIf you choose to call a function ONLY reply in the following format with NO suffix:\n\n\n\n\nvalue_1\n\n\nThis is the value for the second parameter\nthat can span\nmultiple lines\n\n\n\n\n\nReminder:\n- Function calls MUST follow the specified format: an inner block must be nested within XML tags\n- Required parameters MUST be specified\n- You may provide optional reasoning for your function call in natural language BEFORE the function call, but NOT after\n- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls\n' }} + {%- if messages[0].role == 'system' %} + {%- set content = render_content(messages[0].content, false, true)|trim %} + {%- if content %} + {{- '\n\n' + content }} + {%- endif %} + {%- endif %} + {{- '<|im_end|>\n' }} +{%- else %} + {%- if messages[0].role == 'system' %} + {%- set content = render_content(messages[0].content, false, true)|trim %} + {{- '<|im_start|>system\n' + content + '<|im_end|>\n' }} + {%- endif %} +{%- endif %} +{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %} +{%- for message in messages[::-1] %} + {%- set index = (messages|length - 1) - loop.index0 %} + {%- if ns.multi_step_tool and message.role == "user" %} + {%- set content = render_content(message.content, false)|trim %} + {%- if not(content.startswith('') and content.endswith('')) %} + {%- set ns.multi_step_tool = false %} + {%- set ns.last_query_index = index %} + {%- endif %} + {%- endif %} +{%- endfor %} +{%- if ns.multi_step_tool %} + {{- raise_exception('No user query found in messages.') }} +{%- endif %} +{%- for message in messages %} + {%- set content = render_content(message.content, true)|trim %} + {%- if message.role == "system" %} + {%- if not loop.first %} + {{- raise_exception('System message must be at the beginning.') }} + {%- endif %} + {%- elif message.role == "user" %} + {{- '<|im_start|>' + message.role + '\n' + content + '<|im_end|>' + '\n' }} + {%- elif message.role == "assistant" %} + {%- set reasoning_content = '' %} + {%- if message.reasoning_content is string %} + {%- set reasoning_content = message.reasoning_content %} + {%- else %} + {%- if '' in content %} + {%- set reasoning_content = content.split('')[0].rstrip('\n').split('')[-1].lstrip('\n') %} + {%- set content = content.split('')[-1].lstrip('\n') %} + {%- endif %} + {%- endif %} + {%- set reasoning_content = reasoning_content|trim %} + {%- if loop.index0 > ns.last_query_index %} + {{- '<|im_start|>' + message.role + '\n\n' + reasoning_content + '\n\n\n' + content }} + {%- else %} + {{- '<|im_start|>' + message.role + '\n' + content }} + {%- endif %} + {%- if message.tool_calls and message.tool_calls is iterable and message.tool_calls is not mapping %} + {%- for tool_call in message.tool_calls %} + {%- if tool_call.function is defined %} + {%- set tool_call = tool_call.function %} + {%- endif %} + {%- if loop.first %} + {%- if content|trim %} + {{- '\n\n\n\n' }} + {%- else %} + {{- '\n\n' }} + {%- endif %} + {%- else %} + {{- '\n\n\n' }} + {%- endif %} + {%- if tool_call.arguments is defined %} + {%- for args_name, args_value in tool_call.arguments|items %} + {{- '\n' }} + {%- set args_value = args_value | tojson | safe if args_value is mapping or (args_value is sequence and args_value is not string) else args_value | string %} + {{- args_value }} + {{- '\n\n' }} + {%- endfor %} + {%- endif %} + {{- '\n' }} + {%- endfor %} + {%- endif %} + {{- '<|im_end|>\n' }} + {%- elif message.role == "tool" %} + {%- if loop.previtem and loop.previtem.role != "tool" %} + {{- '<|im_start|>user' }} + {%- endif %} + {{- '\n\n' }} + {{- content }} + {{- '\n' }} + {%- if not loop.last and loop.nextitem.role != "tool" %} + {{- '<|im_end|>\n' }} + {%- elif loop.last %} + {{- '<|im_end|>\n' }} + {%- endif %} + {%- else %} + {{- raise_exception('Unexpected message role.') }} + {%- endif %} +{%- endfor %} +{%- if add_generation_prompt %} + {{- '<|im_start|>assistant\n' }} + {%- if enable_thinking is defined and enable_thinking is true %} + {{- '\n' }} + {%- else %} + {{- '\n\n\n\n' }} + {%- endif %} +{%- endif %} \ No newline at end of file diff --git a/trl/chat_templates/qwen3_5_4b_and_above.jinja b/trl/chat_templates/qwen3_5_4b_and_above.jinja new file mode 100644 index 00000000000..a585dec894e --- /dev/null +++ b/trl/chat_templates/qwen3_5_4b_and_above.jinja @@ -0,0 +1,154 @@ +{%- set image_count = namespace(value=0) %} +{%- set video_count = namespace(value=0) %} +{%- macro render_content(content, do_vision_count, is_system_content=false) %} + {%- if content is string %} + {{- content }} + {%- elif content is iterable and content is not mapping %} + {%- for item in content %} + {%- if 'image' in item or 'image_url' in item or item.type == 'image' %} + {%- if is_system_content %} + {{- raise_exception('System message cannot contain images.') }} + {%- endif %} + {%- if do_vision_count %} + {%- set image_count.value = image_count.value + 1 %} + {%- endif %} + {%- if add_vision_id %} + {{- 'Picture ' ~ image_count.value ~ ': ' }} + {%- endif %} + {{- '<|vision_start|><|image_pad|><|vision_end|>' }} + {%- elif 'video' in item or item.type == 'video' %} + {%- if is_system_content %} + {{- raise_exception('System message cannot contain videos.') }} + {%- endif %} + {%- if do_vision_count %} + {%- set video_count.value = video_count.value + 1 %} + {%- endif %} + {%- if add_vision_id %} + {{- 'Video ' ~ video_count.value ~ ': ' }} + {%- endif %} + {{- '<|vision_start|><|video_pad|><|vision_end|>' }} + {%- elif 'text' in item %} + {{- item.text }} + {%- else %} + {{- raise_exception('Unexpected item type in content.') }} + {%- endif %} + {%- endfor %} + {%- elif content is none or content is undefined %} + {{- '' }} + {%- else %} + {{- raise_exception('Unexpected content type.') }} + {%- endif %} +{%- endmacro %} +{%- if not messages %} + {{- raise_exception('No messages provided.') }} +{%- endif %} +{%- if tools and tools is iterable and tools is not mapping %} + {{- '<|im_start|>system\n' }} + {{- "# Tools\n\nYou have access to the following functions:\n\n" }} + {%- for tool in tools %} + {{- "\n" }} + {{- tool | tojson }} + {%- endfor %} + {{- "\n" }} + {{- '\n\nIf you choose to call a function ONLY reply in the following format with NO suffix:\n\n\n\n\nvalue_1\n\n\nThis is the value for the second parameter\nthat can span\nmultiple lines\n\n\n\n\n\nReminder:\n- Function calls MUST follow the specified format: an inner block must be nested within XML tags\n- Required parameters MUST be specified\n- You may provide optional reasoning for your function call in natural language BEFORE the function call, but NOT after\n- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls\n' }} + {%- if messages[0].role == 'system' %} + {%- set content = render_content(messages[0].content, false, true)|trim %} + {%- if content %} + {{- '\n\n' + content }} + {%- endif %} + {%- endif %} + {{- '<|im_end|>\n' }} +{%- else %} + {%- if messages[0].role == 'system' %} + {%- set content = render_content(messages[0].content, false, true)|trim %} + {{- '<|im_start|>system\n' + content + '<|im_end|>\n' }} + {%- endif %} +{%- endif %} +{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %} +{%- for message in messages[::-1] %} + {%- set index = (messages|length - 1) - loop.index0 %} + {%- if ns.multi_step_tool and message.role == "user" %} + {%- set content = render_content(message.content, false)|trim %} + {%- if not(content.startswith('') and content.endswith('')) %} + {%- set ns.multi_step_tool = false %} + {%- set ns.last_query_index = index %} + {%- endif %} + {%- endif %} +{%- endfor %} +{%- if ns.multi_step_tool %} + {{- raise_exception('No user query found in messages.') }} +{%- endif %} +{%- for message in messages %} + {%- set content = render_content(message.content, true)|trim %} + {%- if message.role == "system" %} + {%- if not loop.first %} + {{- raise_exception('System message must be at the beginning.') }} + {%- endif %} + {%- elif message.role == "user" %} + {{- '<|im_start|>' + message.role + '\n' + content + '<|im_end|>' + '\n' }} + {%- elif message.role == "assistant" %} + {%- set reasoning_content = '' %} + {%- if message.reasoning_content is string %} + {%- set reasoning_content = message.reasoning_content %} + {%- else %} + {%- if '' in content %} + {%- set reasoning_content = content.split('')[0].rstrip('\n').split('')[-1].lstrip('\n') %} + {%- set content = content.split('')[-1].lstrip('\n') %} + {%- endif %} + {%- endif %} + {%- set reasoning_content = reasoning_content|trim %} + {%- if loop.index0 > ns.last_query_index %} + {{- '<|im_start|>' + message.role + '\n\n' + reasoning_content + '\n\n\n' + content }} + {%- else %} + {{- '<|im_start|>' + message.role + '\n' + content }} + {%- endif %} + {%- if message.tool_calls and message.tool_calls is iterable and message.tool_calls is not mapping %} + {%- for tool_call in message.tool_calls %} + {%- if tool_call.function is defined %} + {%- set tool_call = tool_call.function %} + {%- endif %} + {%- if loop.first %} + {%- if content|trim %} + {{- '\n\n\n\n' }} + {%- else %} + {{- '\n\n' }} + {%- endif %} + {%- else %} + {{- '\n\n\n' }} + {%- endif %} + {%- if tool_call.arguments is defined %} + {%- for args_name, args_value in tool_call.arguments|items %} + {{- '\n' }} + {%- set args_value = args_value | tojson | safe if args_value is mapping or (args_value is sequence and args_value is not string) else args_value | string %} + {{- args_value }} + {{- '\n\n' }} + {%- endfor %} + {%- endif %} + {{- '\n' }} + {%- endfor %} + {%- endif %} + {{- '<|im_end|>\n' }} + {%- elif message.role == "tool" %} + {%- if loop.previtem and loop.previtem.role != "tool" %} + {{- '<|im_start|>user' }} + {%- endif %} + {{- '\n\n' }} + {{- content }} + {{- '\n' }} + {%- if not loop.last and loop.nextitem.role != "tool" %} + {{- '<|im_end|>\n' }} + {%- elif loop.last %} + {{- '<|im_end|>\n' }} + {%- endif %} + {%- else %} + {{- raise_exception('Unexpected message role.') }} + {%- endif %} +{%- endfor %} +{%- if add_generation_prompt %} + {{- '<|im_start|>assistant\n' }} + {%- if enable_thinking is defined and enable_thinking is false %} + {{- '\n\n\n\n' }} + {%- else %} + {{- '\n' }} + {%- endif %} +{%- endif %} \ No newline at end of file diff --git a/trl/chat_templates/qwen3_training.jinja b/trl/chat_templates/qwen3_training.jinja new file mode 100644 index 00000000000..978e58e3ebf --- /dev/null +++ b/trl/chat_templates/qwen3_training.jinja @@ -0,0 +1,88 @@ +{#- Training variant of the Qwen3 chat template (see qwen3.jinja for the original). + Modifications vs the original: + - {%- if '' in content %} → {%- if '' in content and '' in content %} + Always check for both tags to avoid edge cases where the model generates only one tag. + - Removed the loop.index0 > ns.last_query_index conditional; always include thinking block. + This makes the template prefix-preserving for the [user, assistant] → [user, assistant, tool] transition. +-#} +{%- if tools %} + {{- '<|im_start|>system\n' }} + {%- if messages[0].role == 'system' %} + {{- messages[0].content + '\n\n' }} + {%- endif %} + {{- "# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within XML tags:\n" }} + {%- for tool in tools %} + {{- "\n" }} + {{- tool | tojson }} + {%- endfor %} + {{- "\n\n\nFor each function call, return a json object with function name and arguments within XML tags:\n\n{\"name\": , \"arguments\": }\n<|im_end|>\n" }} +{%- else %} + {%- if messages[0].role == 'system' %} + {{- '<|im_start|>system\n' + messages[0].content + '<|im_end|>\n' }} + {%- endif %} +{%- endif %} +{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %} +{%- for message in messages[::-1] %} + {%- set index = (messages|length - 1) - loop.index0 %} + {%- if ns.multi_step_tool and message.role == "user" and message.content is string and not(message.content.startswith('') and message.content.endswith('')) %} + {%- set ns.multi_step_tool = false %} + {%- set ns.last_query_index = index %} + {%- endif %} +{%- endfor %} +{%- for message in messages %} + {%- if message.content is string %} + {%- set content = message.content %} + {%- else %} + {%- set content = '' %} + {%- endif %} + {%- if (message.role == "user") or (message.role == "system" and not loop.first) %} + {{- '<|im_start|>' + message.role + '\n' + content + '<|im_end|>' + '\n' }} + {%- elif message.role == "assistant" %} + {%- set reasoning_content = '' %} + {%- if message.reasoning_content is string %} + {%- set reasoning_content = message.reasoning_content %} + {%- else %} + {%- if '' in content and '' in content %} + {%- set reasoning_content = content.split('')[0].rstrip('\n').split('')[-1].lstrip('\n') %} + {%- set content = content.split('')[-1].lstrip('\n') %} + {%- endif %} + {%- endif %} + {{- '<|im_start|>' + message.role + '\n\n' + reasoning_content.strip('\n') + '\n\n\n' + content.lstrip('\n') }} + {%- if message.tool_calls %} + {%- for tool_call in message.tool_calls %} + {%- if (loop.first and content) or (not loop.first) %} + {{- '\n' }} + {%- endif %} + {%- if tool_call.function %} + {%- set tool_call = tool_call.function %} + {%- endif %} + {{- '\n{"name": "' }} + {{- tool_call.name }} + {{- '", "arguments": ' }} + {%- if tool_call.arguments is string %} + {{- tool_call.arguments }} + {%- else %} + {{- tool_call.arguments | tojson }} + {%- endif %} + {{- '}\n' }} + {%- endfor %} + {%- endif %} + {{- '<|im_end|>\n' }} + {%- elif message.role == "tool" %} + {%- if loop.first or (messages[loop.index0 - 1].role != "tool") %} + {{- '<|im_start|>user' }} + {%- endif %} + {{- '\n\n' }} + {{- content }} + {{- '\n' }} + {%- if loop.last or (messages[loop.index0 + 1].role != "tool") %} + {{- '<|im_end|>\n' }} + {%- endif %} + {%- endif %} +{%- endfor %} +{%- if add_generation_prompt %} + {{- '<|im_start|>assistant\n' }} + {%- if enable_thinking is defined and enable_thinking is false %} + {{- '\n\n\n\n' }} + {%- endif %} +{%- endif %} \ No newline at end of file From 276559da2718c411c4df7f8ca397c7dfcfc6503b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Sun, 5 Apr 2026 19:48:57 +0000 Subject: [PATCH 05/17] tools in dummy --- trl/chat_template_utils.py | 36 ++++++++++++++++++------------------ trl/trainer/grpo_trainer.py | 8 +++++++- 2 files changed, 25 insertions(+), 19 deletions(-) diff --git a/trl/chat_template_utils.py b/trl/chat_template_utils.py index 3e86874816a..615722cac8c 100644 --- a/trl/chat_template_utils.py +++ b/trl/chat_template_utils.py @@ -16,6 +16,7 @@ from transformers import AddedToken, AutoTokenizer, PreTrainedModel, PreTrainedTokenizer + _CHAT_TEMPLATES_DIR = Path(__file__).parent / "chat_templates" @@ -237,9 +238,9 @@ def is_chat_template_prefix_preserving(tokenizer: PreTrainedTokenizer) -> bool: """ Check whether the chat template preserves prefixes when applied. - A prefix-preserving chat template renders earlier messages identically regardless of what messages follow. - This property is required by `_get_tool_suffix_ids`, which extracts tool response formatting tokens by - comparing tokenizations with and without tool messages appended. + A prefix-preserving chat template renders earlier messages identically regardless of what messages follow. This + property is required by `_get_tool_suffix_ids`, which extracts tool response formatting tokens by comparing + tokenizations with and without tool messages appended. Args: tokenizer (`PreTrainedTokenizer`): @@ -249,22 +250,16 @@ def is_chat_template_prefix_preserving(tokenizer: PreTrainedTokenizer) -> bool: `bool`: `True` if the chat template preserves prefixes, `False` otherwise. """ + # Use the same dummy messages as _get_tool_suffix_ids to test the exact property it relies on. + dummy_tool_calls = [{"type": "function", "function": {"name": "dummy", "arguments": {}}}] messages1 = [ - {"role": "user", "content": "What is 2 * 3?"}, - { - "role": "assistant", - "content": "", - "tool_calls": [{"type": "function", "function": {"name": "multiply", "arguments": {"a": 2, "b": 3}}}], - }, + {"role": "user", "content": "dummy"}, + {"role": "assistant", "tool_calls": dummy_tool_calls}, ] messages2 = [ - {"role": "user", "content": "What is 2 * 3?"}, - { - "role": "assistant", - "content": "", - "tool_calls": [{"type": "function", "function": {"name": "multiply", "arguments": {"a": 2, "b": 3}}}], - }, - {"role": "tool", "name": "multiply", "content": "6"}, + {"role": "user", "content": "dummy"}, + {"role": "assistant", "tool_calls": dummy_tool_calls}, + {"role": "tool", "name": "dummy", "content": "dummy"}, ] text1 = tokenizer.apply_chat_template(messages1, tokenize=False) @@ -300,7 +295,10 @@ def get_training_chat_template(tokenizer: PreTrainedTokenizer) -> str | None: >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B") >>> messages1 = [ ... {"role": "user", "content": "What is 2 * 3?"}, - ... {"role": "assistant", "content": "", "tool_calls": [{"type": "function", "function": {"name": "multiply", "arguments": {"a": 2, "b": 3}}}]}, + ... { + ... "role": "assistant", + ... "tool_calls": [{"type": "function", "function": {"name": "multiply", "arguments": {"a": 2, "b": 3}}}], + ... }, ... ] >>> messages2 = messages1 + [ ... {"role": "tool", "name": "multiply", "content": "6"}, @@ -316,7 +314,9 @@ def get_training_chat_template(tokenizer: PreTrainedTokenizer) -> str | None: >>> tokenizer.apply_chat_template(messages1, tokenize=False, chat_template=chat_template) '<|im_start|>user\nWhat is 2 * 3?<|im_end|>\n<|im_start|>assistant\n\n\n\n\n\n{"name": "multiply", "arguments": {"a": 2, "b": 3}}\n<|im_end|>\n' - >>> tokenizer.apply_chat_template(messages2, tokenize=False, add_generation_prompt=True, chat_template=chat_template) + >>> tokenizer.apply_chat_template( + ... messages2, tokenize=False, add_generation_prompt=True, chat_template=chat_template + ... ) '<|im_start|>user\nWhat is 2 * 3?<|im_end|>\n<|im_start|>assistant\n\n\n\n\n\n{"name": "multiply", "arguments": {"a": 2, "b": 3}}\n<|im_end|>\n<|im_start|>user\n\n6\n<|im_end|>\n<|im_start|>assistant\n' ``` """ diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index 0cd9c53aabe..d36e1237eb3 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -1406,7 +1406,13 @@ def _generate_single_turn(self, prompt_ids, images, multimodal_fields): def _get_tool_suffix_ids(self, tool_messages): """Get token IDs for tool result formatting by using a minimal dummy conversation.""" - dummy_messages = [{"role": "user", "content": "dummy"}, {"role": "assistant", "content": "dummy"}] + dummy_messages = [ + {"role": "user", "content": "dummy"}, + { + "role": "assistant", + "tool_calls": [{"type": "function", "function": {"name": "dummy", "arguments": {}}}], + }, + ] if self._is_vlm: dummy_messages = prepare_multimodal_messages(dummy_messages) prefix_ids = self.processing_class.apply_chat_template( From 673c35d21fab7d533cdf493d5513b243f548e1c4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Sun, 5 Apr 2026 19:56:55 +0000 Subject: [PATCH 06/17] Add chat template files to MANIFEST.in --- MANIFEST.in | 2 ++ 1 file changed, 2 insertions(+) diff --git a/MANIFEST.in b/MANIFEST.in index 16cf649ce60..e843c590270 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -3,6 +3,8 @@ include CONTRIBUTING.md include README.md include trl/accelerate_configs/*.yaml include trl/templates/*.md +include trl/chat_templates/*.jinja +include trl/chat_templates/*.md include trl/skills/**/*.md recursive-exclude * __pycache__ prune tests From 604c476258df074c66729e72ec591ea9db702c4d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Sun, 5 Apr 2026 20:10:46 +0000 Subject: [PATCH 07/17] Enhance chat template handling to include tool call formatting in messages + async grpo --- tests/test_chat_template_utils.py | 24 ++++++++++++++++--- .../async_grpo/async_rollout_worker.py | 5 ++-- 2 files changed, 24 insertions(+), 5 deletions(-) diff --git a/tests/test_chat_template_utils.py b/tests/test_chat_template_utils.py index 2f6cecbb262..b65231fb76c 100644 --- a/tests/test_chat_template_utils.py +++ b/tests/test_chat_template_utils.py @@ -148,13 +148,23 @@ def test_prefix_preserving_template(self): tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen3MoeForSequenceClassification") tokenizer.chat_template = textwrap.dedent(r""" {%- for message in messages %} + {%- set content = message.content if message.content is string else '' %} {%- if message.role == 'user' %} - {{- '<|im_start|>user\n' + message.content + '<|im_end|>\n' }} + {{- '<|im_start|>user\n' + content + '<|im_end|>\n' }} {%- elif message.role == 'assistant' %} - {{- '<|im_start|>assistant\n' + message.content + '<|im_end|>\n' }} + {{- '<|im_start|>assistant\n' + content }} + {%- if message.tool_calls %} + {%- for tool_call in message.tool_calls %} + {%- if tool_call.function %} + {%- set tool_call = tool_call.function %} + {%- endif %} + {{- '' + tool_call.name + '' }} + {%- endfor %} + {%- endif %} + {{- '<|im_end|>\n' }} {%- elif message.role == 'tool' %} - {{- '<|im_start|>tool\n' + message.content + '<|im_end|>\n' }} + {{- '<|im_start|>tool\n' + content + '<|im_end|>\n' }} {%- endif %} {%- endfor %} @@ -205,6 +215,14 @@ def test_non_prefix_preserving_template(self): {%- else %} {{- '<|im_start|>' + message.role + '\n' + content }} {%- endif %} + {%- if message.tool_calls %} + {%- for tool_call in message.tool_calls %} + {%- if tool_call.function %} + {%- set tool_call = tool_call.function %} + {%- endif %} + {{- '' + tool_call.name + '' }} + {%- endfor %} + {%- endif %} {{- '<|im_end|>\n' }} {%- elif message.role == "tool" %} {{- '<|im_start|>tool\n' + content + '<|im_end|>\n' }} diff --git a/trl/experimental/async_grpo/async_rollout_worker.py b/trl/experimental/async_grpo/async_rollout_worker.py index 3d7350d5a71..d4b64860692 100644 --- a/trl/experimental/async_grpo/async_rollout_worker.py +++ b/trl/experimental/async_grpo/async_rollout_worker.py @@ -555,9 +555,10 @@ async def _generate_one( def _get_tool_suffix_ids(self, tool_messages: list[dict[str, Any]]) -> list[int]: """Get token IDs for tool result formatting by using a minimal dummy conversation.""" + dummy_tool_calls = [{"type": "function", "function": {"name": "dummy", "arguments": {}}}] dummy_messages = [ - {"role": "user", "content": ""}, - {"role": "assistant", "content": ""}, + {"role": "user", "content": "dummy"}, + {"role": "assistant", "content": "dummy", "tool_calls": dummy_tool_calls}, ] prefix_ids = self.tokenizer.apply_chat_template( dummy_messages, From 83a7ef66d9d563d3b57c6cf56b8b25d4f17f6d9c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Sun, 5 Apr 2026 20:18:59 +0000 Subject: [PATCH 08/17] align grpo and async --- .../async_grpo/async_rollout_worker.py | 21 +++++++++++-------- trl/trainer/grpo_trainer.py | 6 ++---- 2 files changed, 14 insertions(+), 13 deletions(-) diff --git a/trl/experimental/async_grpo/async_rollout_worker.py b/trl/experimental/async_grpo/async_rollout_worker.py index d4b64860692..2d0e4b2b6c7 100644 --- a/trl/experimental/async_grpo/async_rollout_worker.py +++ b/trl/experimental/async_grpo/async_rollout_worker.py @@ -558,29 +558,32 @@ def _get_tool_suffix_ids(self, tool_messages: list[dict[str, Any]]) -> list[int] dummy_tool_calls = [{"type": "function", "function": {"name": "dummy", "arguments": {}}}] dummy_messages = [ {"role": "user", "content": "dummy"}, - {"role": "assistant", "content": "dummy", "tool_calls": dummy_tool_calls}, + {"role": "assistant", "tool_calls": dummy_tool_calls}, ] prefix_ids = self.tokenizer.apply_chat_template( dummy_messages, - return_dict=False, - tools=self.tools or None, # `or None`: Llama bug: it renders tool boilerplate for tools=[] + add_generation_prompt=False, + tokenize=True, chat_template=self.chat_template, + return_dict=False, **self.chat_template_kwargs, ) full_ids = self.tokenizer.apply_chat_template( dummy_messages + tool_messages, - return_dict=False, - chat_template=self.chat_template, add_generation_prompt=True, - tools=self.tools or None, # `or None`: Llama bug: it renders tool boilerplate for tools=[] + tokenize=True, + chat_template=self.chat_template, + return_dict=False, **self.chat_template_kwargs, ) # Some chat templates (notably Qwen3/Qwen3.5) render "...<|im_end|>\n" after an assistant/tool block. # When we compute `suffix_ids` by slicing `full_ids`, we must align the slicing boundary to - # EOS (not EOS + newline). - last_eos_idx = max(i for i, tok_id in enumerate(prefix_ids) if tok_id == self.tokenizer.eos_token_id) - prefix_ids = prefix_ids[: last_eos_idx + 1] + # EOS (not EOS + newline). Templates that don't use EOS as end-of-turn (e.g. Gemma uses + # ) skip this trimming. + eos_positions = [i for i, tok_id in enumerate(prefix_ids) if tok_id == self.tokenizer.eos_token_id] + if eos_positions: + prefix_ids = prefix_ids[: eos_positions[-1] + 1] if full_ids[: len(prefix_ids)] != prefix_ids: raise ValueError("Unexpected tokenization: the EOS-trimmed prefix IDs are not a prefix of the full IDs.") diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index d36e1237eb3..0d172190777 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -1406,12 +1406,10 @@ def _generate_single_turn(self, prompt_ids, images, multimodal_fields): def _get_tool_suffix_ids(self, tool_messages): """Get token IDs for tool result formatting by using a minimal dummy conversation.""" + dummy_tool_calls = [{"type": "function", "function": {"name": "dummy", "arguments": {}}}] dummy_messages = [ {"role": "user", "content": "dummy"}, - { - "role": "assistant", - "tool_calls": [{"type": "function", "function": {"name": "dummy", "arguments": {}}}], - }, + {"role": "assistant", "tool_calls": dummy_tool_calls}, ] if self._is_vlm: dummy_messages = prepare_multimodal_messages(dummy_messages) From e5d7cdf70a0925dc120ab62c0c8078c28a5f70df Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Mon, 6 Apr 2026 14:04:54 +0000 Subject: [PATCH 09/17] revert no content --- tests/test_chat_template_utils.py | 9 ++++----- trl/chat_template_utils.py | 5 +++-- trl/experimental/async_grpo/async_rollout_worker.py | 2 +- trl/trainer/grpo_trainer.py | 2 +- 4 files changed, 9 insertions(+), 9 deletions(-) diff --git a/tests/test_chat_template_utils.py b/tests/test_chat_template_utils.py index b65231fb76c..1d1840c623a 100644 --- a/tests/test_chat_template_utils.py +++ b/tests/test_chat_template_utils.py @@ -148,12 +148,11 @@ def test_prefix_preserving_template(self): tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen3MoeForSequenceClassification") tokenizer.chat_template = textwrap.dedent(r""" {%- for message in messages %} - {%- set content = message.content if message.content is string else '' %} {%- if message.role == 'user' %} - {{- '<|im_start|>user\n' + content + '<|im_end|>\n' }} + {{- '<|im_start|>user\n' + message.content + '<|im_end|>\n' }} {%- elif message.role == 'assistant' %} - {{- '<|im_start|>assistant\n' + content }} + {{- '<|im_start|>assistant\n' + message.content }} {%- if message.tool_calls %} {%- for tool_call in message.tool_calls %} {%- if tool_call.function %} @@ -164,7 +163,7 @@ def test_prefix_preserving_template(self): {%- endif %} {{- '<|im_end|>\n' }} {%- elif message.role == 'tool' %} - {{- '<|im_start|>tool\n' + content + '<|im_end|>\n' }} + {{- '<|im_start|>tool\n' + message.content + '<|im_end|>\n' }} {%- endif %} {%- endfor %} @@ -471,7 +470,7 @@ def test_parse_response_tool_call_without_arguments(self, tokenizer_name): tool_calls = [{"type": "function", "function": {"name": "ping", "arguments": {}}}] messages = [ {"role": "user", "content": "Ping the service."}, - {"role": "assistant", "tool_calls": tool_calls}, + {"role": "assistant", "content": "", "tool_calls": tool_calls}, ] prefix = tokenizer.apply_chat_template(messages[:1], add_generation_prompt=True).input_ids text = tokenizer.apply_chat_template(messages).input_ids diff --git a/trl/chat_template_utils.py b/trl/chat_template_utils.py index 615722cac8c..fae77c65f15 100644 --- a/trl/chat_template_utils.py +++ b/trl/chat_template_utils.py @@ -254,11 +254,11 @@ def is_chat_template_prefix_preserving(tokenizer: PreTrainedTokenizer) -> bool: dummy_tool_calls = [{"type": "function", "function": {"name": "dummy", "arguments": {}}}] messages1 = [ {"role": "user", "content": "dummy"}, - {"role": "assistant", "tool_calls": dummy_tool_calls}, + {"role": "assistant", "content": "", "tool_calls": dummy_tool_calls}, ] messages2 = [ {"role": "user", "content": "dummy"}, - {"role": "assistant", "tool_calls": dummy_tool_calls}, + {"role": "assistant", "content": "", "tool_calls": dummy_tool_calls}, {"role": "tool", "name": "dummy", "content": "dummy"}, ] @@ -297,6 +297,7 @@ def get_training_chat_template(tokenizer: PreTrainedTokenizer) -> str | None: ... {"role": "user", "content": "What is 2 * 3?"}, ... { ... "role": "assistant", + ... "content": "", ... "tool_calls": [{"type": "function", "function": {"name": "multiply", "arguments": {"a": 2, "b": 3}}}], ... }, ... ] diff --git a/trl/experimental/async_grpo/async_rollout_worker.py b/trl/experimental/async_grpo/async_rollout_worker.py index 2d0e4b2b6c7..1a63a7d8e90 100644 --- a/trl/experimental/async_grpo/async_rollout_worker.py +++ b/trl/experimental/async_grpo/async_rollout_worker.py @@ -558,7 +558,7 @@ def _get_tool_suffix_ids(self, tool_messages: list[dict[str, Any]]) -> list[int] dummy_tool_calls = [{"type": "function", "function": {"name": "dummy", "arguments": {}}}] dummy_messages = [ {"role": "user", "content": "dummy"}, - {"role": "assistant", "tool_calls": dummy_tool_calls}, + {"role": "assistant", "content": "", "tool_calls": dummy_tool_calls}, ] prefix_ids = self.tokenizer.apply_chat_template( dummy_messages, diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index 0d172190777..aa64eec2b5a 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -1409,7 +1409,7 @@ def _get_tool_suffix_ids(self, tool_messages): dummy_tool_calls = [{"type": "function", "function": {"name": "dummy", "arguments": {}}}] dummy_messages = [ {"role": "user", "content": "dummy"}, - {"role": "assistant", "tool_calls": dummy_tool_calls}, + {"role": "assistant", "content": "", "tool_calls": dummy_tool_calls}, ] if self._is_vlm: dummy_messages = prepare_multimodal_messages(dummy_messages) From a618809dbe7460f5ee5092545f91c7e29d641df1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Mon, 6 Apr 2026 14:06:33 +0000 Subject: [PATCH 10/17] docstyle ignore --- tests/test_chat_template_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_chat_template_utils.py b/tests/test_chat_template_utils.py index 1d1840c623a..50eb8af4aad 100644 --- a/tests/test_chat_template_utils.py +++ b/tests/test_chat_template_utils.py @@ -146,6 +146,7 @@ def test_add_response_schema(self, tokenizer_name): class TestIsChatTemplatePrefixPreserving: def test_prefix_preserving_template(self): tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen3MoeForSequenceClassification") + # docstyle-ignore tokenizer.chat_template = textwrap.dedent(r""" {%- for message in messages %} From b18e39efd6a562012f441075217f592ca5647147 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Mon, 6 Apr 2026 21:56:13 +0000 Subject: [PATCH 11/17] Add GPT-OSS tool calling support --- docs/source/grpo_trainer.md | 1 + tests/test_chat_template_utils.py | 11 +- trl/chat_template_utils.py | 45 ++++ trl/chat_templates/README.md | 4 + trl/chat_templates/gptoss.jinja | 331 ++++++++++++++++++++++++++++++ 5 files changed, 390 insertions(+), 2 deletions(-) create mode 100644 trl/chat_templates/gptoss.jinja diff --git a/docs/source/grpo_trainer.md b/docs/source/grpo_trainer.md index 5176c3da744..a133e192e95 100644 --- a/docs/source/grpo_trainer.md +++ b/docs/source/grpo_trainer.md @@ -741,6 +741,7 @@ The returned images are automatically injected into the conversation and passed Tested with: - [**Gemma4**](https://huggingface.co/collections/google/gemma-4) — e.g., `google/gemma-4-E2B-it` +- [**GPT-OSS**](https://huggingface.co/collections/openai/gpt-oss) — e.g., `openai/gpt-4.1-nano` - [**Qwen3**](https://huggingface.co/collections/Qwen/qwen3) — e.g., `Qwen/Qwen3-0.6B` - [**Qwen3.5**](https://huggingface.co/collections/Qwen/qwen35) — e.g., `Qwen/Qwen3.5-2B` diff --git a/tests/test_chat_template_utils.py b/tests/test_chat_template_utils.py index 155bf77d9b4..8d9f32c2874 100644 --- a/tests/test_chat_template_utils.py +++ b/tests/test_chat_template_utils.py @@ -113,6 +113,7 @@ def test_clone_with_sequence_classification_model(self): @pytest.mark.parametrize( "tokenizer_name", [ + pytest.param("trl-internal-testing/tiny-GptOssForCausalLM", id="gptoss"), pytest.param("trl-internal-testing/tiny-Qwen3MoeForSequenceClassification", id="qwen3"), pytest.param("trl-internal-testing/tiny-Qwen3_5ForConditionalGeneration", id="qwen35"), ], @@ -396,6 +397,7 @@ def test_behavior_unchanged_generation_prompt_with_enable_thinking_false(self, t @pytest.mark.parametrize( "tokenizer_name", [ + pytest.param("trl-internal-testing/tiny-GptOssForCausalLM", id="gptoss"), pytest.param("trl-internal-testing/tiny-Qwen3MoeForSequenceClassification", id="qwen3"), pytest.param("trl-internal-testing/tiny-Qwen3_5ForConditionalGeneration", id="qwen35"), pytest.param( @@ -430,8 +432,11 @@ def test_parse_response(self, tokenizer_name): assert parsed == messages[-1] def test_parse_response_with_reasoning_content(self, tokenizer_name): - if tokenizer_name == "trl-internal-testing/tiny-Gemma4ForConditionalGeneration": - pytest.skip("Gemma4 doesn't support inline reasoning_content.") + if tokenizer_name in ( + "trl-internal-testing/tiny-Gemma4ForConditionalGeneration", + "trl-internal-testing/tiny-GptOssForCausalLM", + ): + pytest.skip("This model doesn't support inline reasoning_content.") tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) if getattr(tokenizer, "response_schema", None) is None: tokenizer = add_response_schema(tokenizer) @@ -499,6 +504,8 @@ def test_parse_response_tool_call_without_arguments(self, tokenizer_name): assert parsed == {"role": "assistant", "content": "", "tool_calls": tool_calls} def test_parse_response_multiple_tool_calls(self, tokenizer_name): + if tokenizer_name == "trl-internal-testing/tiny-GptOssForCausalLM": + pytest.skip("GPT-OSS template only renders one tool call per assistant message.") tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) if getattr(tokenizer, "response_schema", None) is None: tokenizer = add_response_schema(tokenizer) diff --git a/trl/chat_template_utils.py b/trl/chat_template_utils.py index c85a7b45a04..b0b8c3ae272 100644 --- a/trl/chat_template_utils.py +++ b/trl/chat_template_utils.py @@ -183,6 +183,48 @@ def clone_chat_template( }, } +gptoss_schema = { + # Normalize final content to analysis format so both map to the same "content" group. + "x-regex-substitutions": [ + [r"<\|channel\|>final<\|message\|>(.*?)<\|return\|>", r"<|channel|>analysis<|message|>\1<|end|>"], + ], + "x-regex": r"^(?:<\|channel\|>analysis<\|message\|>(?P.*?)<\|end\|>(?:<\|start\|>assistant)?)?\s*(?Pto=functions\.\S+<\|channel\|>commentary json<\|message\|>.*?<\|call\|>)?$", + "type": "object", + "properties": { + "role": {"const": "assistant"}, + "content": {"type": "string"}, + "tool_calls": { + "type": "array", + "x-regex-iterator": r"(to=functions\.\S+<\|channel\|>commentary json<\|message\|>.*?<\|call\|>)", + "items": { + # Convert "to=functions.NAME<|channel|>commentary json<|message|>ARGS<|call|>" + # into '{"name": "NAME", "arguments": ARGS}' so it can be parsed as JSON. + "x-regex-substitutions": [ + [r"to=functions\.(\S+)<\|channel\|>commentary json<\|message\|>(.*?)<\|call\|>", r'{"name": "\1", "arguments": \2}'], + ], + "x-parser": "json", + "x-parser-args": {"transform": "{type: 'function', function: @}"}, + "type": "object", + "properties": { + "type": {"const": "function"}, + "function": { + "type": "object", + "properties": { + "name": {"type": "string"}, + "arguments": { + "type": "object", + "additionalProperties": {}, + }, + }, + }, + }, + }, + }, + }, +} + +gptoss_chat_template = (_CHAT_TEMPLATES_DIR / "gptoss.jinja").read_text() + qwen3_chat_template = (_CHAT_TEMPLATES_DIR / "qwen3.jinja").read_text() qwen3_5_chat_template_2b_and_below = (_CHAT_TEMPLATES_DIR / "qwen3_5_2b_and_below.jinja").read_text() @@ -220,6 +262,9 @@ def add_response_schema(tokenizer: PreTrainedTokenizer) -> PreTrainedTokenizer: {'role': 'assistant', 'content': '', 'tool_calls': [{'type': 'function', 'function': {'name': 'multiply', 'arguments': {'a': 3, 'b': 4}}}]} ``` """ + if tokenizer.chat_template == gptoss_chat_template: + tokenizer.response_schema = gptoss_schema + return tokenizer if tokenizer.chat_template == qwen3_chat_template: tokenizer.response_schema = qwen3_schema return tokenizer diff --git a/trl/chat_templates/README.md b/trl/chat_templates/README.md index 5eb01cc4aca..4a7ac33e8c8 100644 --- a/trl/chat_templates/README.md +++ b/trl/chat_templates/README.md @@ -11,6 +11,10 @@ Jinja2 chat templates stored here serve two purposes: Used for identity comparison only. +### `gptoss.jinja` + +Original GPT-OSS chat template. + ### `qwen3.jinja` Original Qwen3 chat template. diff --git a/trl/chat_templates/gptoss.jinja b/trl/chat_templates/gptoss.jinja new file mode 100644 index 00000000000..dc7bb11927d --- /dev/null +++ b/trl/chat_templates/gptoss.jinja @@ -0,0 +1,331 @@ +{#- + In addition to the normal inputs of `messages` and `tools`, this template also accepts the + following kwargs: + - "builtin_tools": A list, can contain "browser" and/or "python". + - "model_identity": A string that optionally describes the model identity. + - "reasoning_effort": A string that describes the reasoning effort, defaults to "medium". + #} + +{#- Tool Definition Rendering ============================================== #} +{%- macro render_typescript_type(param_spec, required_params, is_nullable=false) -%} + {%- if param_spec.type == "array" -%} + {%- if param_spec['items'] -%} + {%- if param_spec['items']['type'] == "string" -%} + {{- "string[]" }} + {%- elif param_spec['items']['type'] == "number" -%} + {{- "number[]" }} + {%- elif param_spec['items']['type'] == "integer" -%} + {{- "number[]" }} + {%- elif param_spec['items']['type'] == "boolean" -%} + {{- "boolean[]" }} + {%- else -%} + {%- set inner_type = render_typescript_type(param_spec['items'], required_params) -%} + {%- if inner_type == "object | object" or inner_type|length > 50 -%} + {{- "any[]" }} + {%- else -%} + {{- inner_type + "[]" }} + {%- endif -%} + {%- endif -%} + {%- if param_spec.nullable -%} + {{- " | null" }} + {%- endif -%} + {%- else -%} + {{- "any[]" }} + {%- if param_spec.nullable -%} + {{- " | null" }} + {%- endif -%} + {%- endif -%} + {%- elif param_spec.type is defined and param_spec.type is iterable and param_spec.type is not string and param_spec.type is not mapping and param_spec.type[0] is defined -%} + {#- Handle array of types like ["object", "object"] from Union[dict, list] #} + {%- if param_spec.type | length > 1 -%} + {{- param_spec.type | join(" | ") }} + {%- else -%} + {{- param_spec.type[0] }} + {%- endif -%} + {%- elif param_spec.oneOf -%} + {#- Handle oneOf schemas - check for complex unions and fallback to any #} + {%- set has_object_variants = false -%} + {%- for variant in param_spec.oneOf -%} + {%- if variant.type == "object" -%} + {%- set has_object_variants = true -%} + {%- endif -%} + {%- endfor -%} + {%- if has_object_variants and param_spec.oneOf|length > 1 -%} + {{- "any" }} + {%- else -%} + {%- for variant in param_spec.oneOf -%} + {{- render_typescript_type(variant, required_params) -}} + {%- if variant.description %} + {{- "// " + variant.description }} + {%- endif -%} + {%- if variant.default is defined %} + {{ "// default: " + variant.default|tojson }} + {%- endif -%} + {%- if not loop.last %} + {{- " | " }} + {% endif -%} + {%- endfor -%} + {%- endif -%} + {%- elif param_spec.type == "string" -%} + {%- if param_spec.enum -%} + {{- '"' + param_spec.enum|join('" | "') + '"' -}} + {%- else -%} + {{- "string" }} + {%- if param_spec.nullable %} + {{- " | null" }} + {%- endif -%} + {%- endif -%} + {%- elif param_spec.type == "number" -%} + {{- "number" }} + {%- elif param_spec.type == "integer" -%} + {{- "number" }} + {%- elif param_spec.type == "boolean" -%} + {{- "boolean" }} + + {%- elif param_spec.type == "object" -%} + {%- if param_spec.properties -%} + {{- "{\n" }} + {%- for prop_name, prop_spec in param_spec.properties.items() -%} + {{- prop_name -}} + {%- if prop_name not in (param_spec.required or []) -%} + {{- "?" }} + {%- endif -%} + {{- ": " }} + {{ render_typescript_type(prop_spec, param_spec.required or []) }} + {%- if not loop.last -%} + {{-", " }} + {%- endif -%} + {%- endfor -%} + {{- "}" }} + {%- else -%} + {{- "object" }} + {%- endif -%} + {%- else -%} + {{- "any" }} + {%- endif -%} +{%- endmacro -%} + +{%- macro render_tool_namespace(namespace_name, tools) -%} + {{- "## " + namespace_name + "\n\n" }} + {{- "namespace " + namespace_name + " {\n\n" }} + {%- for tool in tools %} + {%- set tool = tool.function %} + {{- "// " + tool.description + "\n" }} + {{- "type "+ tool.name + " = " }} + {%- if tool.parameters and tool.parameters.properties %} + {{- "(_: {\n" }} + {%- for param_name, param_spec in tool.parameters.properties.items() %} + {%- if param_spec.description %} + {{- "// " + param_spec.description + "\n" }} + {%- endif %} + {{- param_name }} + {%- if param_name not in (tool.parameters.required or []) -%} + {{- "?" }} + {%- endif -%} + {{- ": " }} + {{- render_typescript_type(param_spec, tool.parameters.required or []) }} + {%- if param_spec.default is defined -%} + {%- if param_spec.enum %} + {{- ", // default: " + param_spec.default }} + {%- elif param_spec.oneOf %} + {{- "// default: " + param_spec.default }} + {%- else %} + {{- ", // default: " + param_spec.default|tojson }} + {%- endif -%} + {%- endif -%} + {%- if not loop.last %} + {{- ",\n" }} + {%- else %} + {{- ",\n" }} + {%- endif -%} + {%- endfor %} + {{- "}) => any;\n\n" }} + {%- else -%} + {{- "() => any;\n\n" }} + {%- endif -%} + {%- endfor %} + {{- "} // namespace " + namespace_name }} +{%- endmacro -%} + +{%- macro render_builtin_tools(browser_tool, python_tool) -%} + {%- if browser_tool %} + {{- "## browser\n\n" }} + {{- "// Tool for browsing.\n" }} + {{- "// The `cursor` appears in brackets before each browsing display: `[{cursor}]`.\n" }} + {{- "// Cite information from the tool using the following format:\n" }} + {{- "// `【{cursor}†L{line_start}(-L{line_end})?】`, for example: `【6†L9-L11】` or `【8†L3】`.\n" }} + {{- "// Do not quote more than 10 words directly from the tool output.\n" }} + {{- "// sources=web (default: web)\n" }} + {{- "namespace browser {\n\n" }} + {{- "// Searches for information related to `query` and displays `topn` results.\n" }} + {{- "type search = (_: {\n" }} + {{- "query: string,\n" }} + {{- "topn?: number, // default: 10\n" }} + {{- "source?: string,\n" }} + {{- "}) => any;\n\n" }} + {{- "// Opens the link `id` from the page indicated by `cursor` starting at line number `loc`, showing `num_lines` lines.\n" }} + {{- "// Valid link ids are displayed with the formatting: `【{id}†.*】`.\n" }} + {{- "// If `cursor` is not provided, the most recent page is implied.\n" }} + {{- "// If `id` is a string, it is treated as a fully qualified URL associated with `source`.\n" }} + {{- "// If `loc` is not provided, the viewport will be positioned at the beginning of the document or centered on the most relevant passage, if available.\n" }} + {{- "// Use this function without `id` to scroll to a new location of an opened page.\n" }} + {{- "type open = (_: {\n" }} + {{- "id?: number | string, // default: -1\n" }} + {{- "cursor?: number, // default: -1\n" }} + {{- "loc?: number, // default: -1\n" }} + {{- "num_lines?: number, // default: -1\n" }} + {{- "view_source?: boolean, // default: false\n" }} + {{- "source?: string,\n" }} + {{- "}) => any;\n\n" }} + {{- "// Finds exact matches of `pattern` in the current page, or the page given by `cursor`.\n" }} + {{- "type find = (_: {\n" }} + {{- "pattern: string,\n" }} + {{- "cursor?: number, // default: -1\n" }} + {{- "}) => any;\n\n" }} + {{- "} // namespace browser\n\n" }} + {%- endif -%} + + {%- if python_tool %} + {{- "## python\n\n" }} + {{- "Use this tool to execute Python code in your chain of thought. The code will not be shown to the user. This tool should be used for internal reasoning, but not for code that is intended to be visible to the user (e.g. when creating plots, tables, or files).\n\n" }} + {{- "When you send a message containing Python code to python, it will be executed in a stateful Jupyter notebook environment. python will respond with the output of the execution or time out after 120.0 seconds. The drive at '/mnt/data' can be used to save and persist user files. Internet access for this session is UNKNOWN. Depends on the cluster.\n\n" }} + {%- endif -%} +{%- endmacro -%} + +{#- System Message Construction ============================================ #} +{%- macro build_system_message() -%} + {%- if model_identity is not defined %} + {%- set model_identity = "You are ChatGPT, a large language model trained by OpenAI." %} + {%- endif %} + {{- model_identity + "\n" }} + {{- "Knowledge cutoff: 2024-06\n" }} + {{- "Current date: " + strftime_now("%Y-%m-%d") + "\n\n" }} + {%- if reasoning_effort is not defined %} + {%- set reasoning_effort = "medium" %} + {%- endif %} + {{- "Reasoning: " + reasoning_effort + "\n\n" }} + {%- if builtin_tools %} + {{- "# Tools\n\n" }} + {%- set available_builtin_tools = namespace(browser=false, python=false) %} + {%- for tool in builtin_tools %} + {%- if tool == "browser" %} + {%- set available_builtin_tools.browser = true %} + {%- elif tool == "python" %} + {%- set available_builtin_tools.python = true %} + {%- endif %} + {%- endfor %} + {{- render_builtin_tools(available_builtin_tools.browser, available_builtin_tools.python) }} + {%- endif -%} + {{- "# Valid channels: analysis, commentary, final. Channel must be included for every message." }} + {%- if tools -%} + {{- "\nCalls to these tools must go to the commentary channel: 'functions'." }} + {%- endif -%} +{%- endmacro -%} + +{#- Main Template Logic ================================================= #} +{#- Set defaults #} + +{#- Render system message #} +{{- "<|start|>system<|message|>" }} +{{- build_system_message() }} +{{- "<|end|>" }} + +{#- Extract developer message #} +{%- if messages[0].role == "developer" or messages[0].role == "system" %} + {%- set developer_message = messages[0].content %} + {%- set loop_messages = messages[1:] %} +{%- else %} + {%- set developer_message = "" %} + {%- set loop_messages = messages %} +{%- endif %} + +{#- Render developer message #} +{%- if developer_message or tools %} + {{- "<|start|>developer<|message|>" }} + {%- if developer_message %} + {{- "# Instructions\n\n" }} + {{- developer_message }} + {{- "\n\n" }} + {%- endif %} + {%- if tools -%} + {{- "# Tools\n\n" }} + {{- render_tool_namespace("functions", tools) }} + {%- endif -%} + {{- "<|end|>" }} +{%- endif %} + +{#- Render messages #} +{%- set last_tool_call = namespace(name=none) %} +{%- for message in loop_messages -%} + {#- At this point only assistant/user/tool messages should remain #} + {%- if message.role == 'assistant' -%} + {#- Checks to ensure the messages are being passed in the format we expect #} + {%- if "content" in message %} + {%- if "<|channel|>analysis<|message|>" in message.content or "<|channel|>final<|message|>" in message.content %} + {{- raise_exception("You have passed a message containing <|channel|> tags in the content field. Instead of doing this, you should pass analysis messages (the string between '<|message|>' and '<|end|>') in the 'thinking' field, and final messages (the string between '<|message|>' and '<|end|>') in the 'content' field.") }} + {%- endif %} + {%- endif %} + {%- if "thinking" in message %} + {%- if "<|channel|>analysis<|message|>" in message.thinking or "<|channel|>final<|message|>" in message.thinking %} + {{- raise_exception("You have passed a message containing <|channel|> tags in the thinking field. Instead of doing this, you should pass analysis messages (the string between '<|message|>' and '<|end|>') in the 'thinking' field, and final messages (the string between '<|message|>' and '<|end|>') in the 'content' field.") }} + {%- endif %} + {%- endif %} + {%- if "tool_calls" in message %} + {#- We need very careful handling here - we want to drop the tool call analysis message if the model #} + {#- has output a later <|final|> message, but otherwise we want to retain it. This is the only case #} + {#- when we render CoT/analysis messages in inference. #} + {%- set future_final_message = namespace(found=false) %} + {%- for future_message in loop_messages[loop.index:] %} + {%- if future_message.role == 'assistant' and "tool_calls" not in future_message %} + {%- set future_final_message.found = true %} + {%- endif %} + {%- endfor %} + {#- We assume max 1 tool call per message, and so we infer the tool call name #} + {#- in "tool" messages from the most recent assistant tool call name #} + {%- set tool_call = message.tool_calls[0] %} + {%- if tool_call.function %} + {%- set tool_call = tool_call.function %} + {%- endif %} + {%- if message.content and message.thinking %} + {{- raise_exception("Cannot pass both content and thinking in an assistant message with tool calls! Put the analysis message in one or the other, but not both.") }} + {%- elif message.content and not future_final_message.found %} + {{- "<|start|>assistant<|channel|>analysis<|message|>" + message.content + "<|end|>" }} + {%- elif message.thinking and not future_final_message.found %} + {{- "<|start|>assistant<|channel|>analysis<|message|>" + message.thinking + "<|end|>" }} + {%- endif %} + {{- "<|start|>assistant to=" }} + {{- "functions." + tool_call.name + "<|channel|>commentary " }} + {{- (tool_call.content_type if tool_call.content_type is defined else "json") + "<|message|>" }} + {{- tool_call.arguments|tojson }} + {{- "<|call|>" }} + {%- set last_tool_call.name = tool_call.name %} + {%- elif loop.last and not add_generation_prompt %} + {#- Only render the CoT if the final turn is an assistant turn and add_generation_prompt is false #} + {#- This is a situation that should only occur in training, never in inference. #} + {%- if "thinking" in message %} + {{- "<|start|>assistant<|channel|>analysis<|message|>" + message.thinking + "<|end|>" }} + {%- endif %} + {#- <|return|> indicates the end of generation, but <|end|> does not #} + {#- <|return|> should never be an input to the model, but we include it as the final token #} + {#- when training, so the model learns to emit it. #} + {{- "<|start|>assistant<|channel|>final<|message|>" + message.content + "<|return|>" }} + {%- else %} + {#- CoT is dropped during all previous turns, so we never render it for inference #} + {{- "<|start|>assistant<|channel|>final<|message|>" + message.content + "<|end|>" }} + {%- set last_tool_call.name = none %} + {%- endif %} + {%- elif message.role == 'tool' -%} + {%- if last_tool_call.name is none %} + {{- raise_exception("Message has tool role, but there was no previous assistant message with a tool call!") }} + {%- endif %} + {{- "<|start|>functions." + last_tool_call.name }} + {{- " to=assistant<|channel|>commentary<|message|>" + message.content|tojson + "<|end|>" }} + {%- elif message.role == 'user' -%} + {{- "<|start|>user<|message|>" + message.content + "<|end|>" }} + {%- endif -%} +{%- endfor -%} + +{#- Generation prompt #} +{%- if add_generation_prompt -%} +<|start|>assistant +{%- endif -%} \ No newline at end of file From 71ce5a062465753fd9f522c3cf8ce000ca8cba64 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Mon, 6 Apr 2026 21:58:21 +0000 Subject: [PATCH 12/17] fix gpt oss --- docs/source/grpo_trainer.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/grpo_trainer.md b/docs/source/grpo_trainer.md index a133e192e95..37d55c3dfce 100644 --- a/docs/source/grpo_trainer.md +++ b/docs/source/grpo_trainer.md @@ -741,7 +741,7 @@ The returned images are automatically injected into the conversation and passed Tested with: - [**Gemma4**](https://huggingface.co/collections/google/gemma-4) — e.g., `google/gemma-4-E2B-it` -- [**GPT-OSS**](https://huggingface.co/collections/openai/gpt-oss) — e.g., `openai/gpt-4.1-nano` +- [**GPT-OSS**](https://huggingface.co/collections/openai/gpt-oss) — e.g., `openai/gpt-oss-20b` - [**Qwen3**](https://huggingface.co/collections/Qwen/qwen3) — e.g., `Qwen/Qwen3-0.6B` - [**Qwen3.5**](https://huggingface.co/collections/Qwen/qwen35) — e.g., `Qwen/Qwen3.5-2B` From 8f1ad1e076c67d50e9595dfbd3c15a9748af4aca Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Mon, 6 Apr 2026 22:20:08 +0000 Subject: [PATCH 13/17] Update tool suffix ID retrieval to use actual tool names for GPT-OSS compatibility --- trl/experimental/async_grpo/async_rollout_worker.py | 5 ++++- trl/trainer/grpo_trainer.py | 4 +++- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/trl/experimental/async_grpo/async_rollout_worker.py b/trl/experimental/async_grpo/async_rollout_worker.py index 1a63a7d8e90..49d67c7fa3a 100644 --- a/trl/experimental/async_grpo/async_rollout_worker.py +++ b/trl/experimental/async_grpo/async_rollout_worker.py @@ -555,7 +555,10 @@ async def _generate_one( def _get_tool_suffix_ids(self, tool_messages: list[dict[str, Any]]) -> list[int]: """Get token IDs for tool result formatting by using a minimal dummy conversation.""" - dummy_tool_calls = [{"type": "function", "function": {"name": "dummy", "arguments": {}}}] + # Use the real tool name instead of a dummy: some templates (e.g. GPT-OSS) derive the tool response + # header from the assistant's tool call name. + tool_name = tool_messages[0]["name"] + dummy_tool_calls = [{"type": "function", "function": {"name": tool_name, "arguments": {}}}] dummy_messages = [ {"role": "user", "content": "dummy"}, {"role": "assistant", "content": "", "tool_calls": dummy_tool_calls}, diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index 45457afc262..840794b6a16 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -1404,7 +1404,9 @@ def _generate_single_turn(self, prompt_ids, images, multimodal_fields): def _get_tool_suffix_ids(self, tool_messages): """Get token IDs for tool result formatting by using a minimal dummy conversation.""" - dummy_tool_calls = [{"type": "function", "function": {"name": "dummy", "arguments": {}}}] + # Use the real tool name instead of a dummy: some templates (e.g. GPT-OSS) derive the tool response + # header from the assistant's tool call name. + dummy_tool_calls = [{"type": "function", "function": {"name": tool_messages[0]["name"], "arguments": {}}}] dummy_messages = [ {"role": "user", "content": "dummy"}, {"role": "assistant", "content": "", "tool_calls": dummy_tool_calls}, From 9b9771dfab2f94e8c68c948c1760228906c8c0f7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Mon, 6 Apr 2026 22:26:57 +0000 Subject: [PATCH 14/17] style --- trl/chat_template_utils.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/trl/chat_template_utils.py b/trl/chat_template_utils.py index b0b8c3ae272..53a8012cc6a 100644 --- a/trl/chat_template_utils.py +++ b/trl/chat_template_utils.py @@ -200,7 +200,10 @@ def clone_chat_template( # Convert "to=functions.NAME<|channel|>commentary json<|message|>ARGS<|call|>" # into '{"name": "NAME", "arguments": ARGS}' so it can be parsed as JSON. "x-regex-substitutions": [ - [r"to=functions\.(\S+)<\|channel\|>commentary json<\|message\|>(.*?)<\|call\|>", r'{"name": "\1", "arguments": \2}'], + [ + r"to=functions\.(\S+)<\|channel\|>commentary json<\|message\|>(.*?)<\|call\|>", + r'{"name": "\1", "arguments": \2}', + ], ], "x-parser": "json", "x-parser-args": {"transform": "{type: 'function', function: @}"}, From b3f448145be95017fbf6a774fffe0877f32aaf83 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Mon, 6 Apr 2026 22:35:43 +0000 Subject: [PATCH 15/17] align async --- trl/experimental/async_grpo/async_rollout_worker.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/trl/experimental/async_grpo/async_rollout_worker.py b/trl/experimental/async_grpo/async_rollout_worker.py index 49d67c7fa3a..b2bfea6de6f 100644 --- a/trl/experimental/async_grpo/async_rollout_worker.py +++ b/trl/experimental/async_grpo/async_rollout_worker.py @@ -557,8 +557,7 @@ def _get_tool_suffix_ids(self, tool_messages: list[dict[str, Any]]) -> list[int] """Get token IDs for tool result formatting by using a minimal dummy conversation.""" # Use the real tool name instead of a dummy: some templates (e.g. GPT-OSS) derive the tool response # header from the assistant's tool call name. - tool_name = tool_messages[0]["name"] - dummy_tool_calls = [{"type": "function", "function": {"name": tool_name, "arguments": {}}}] + dummy_tool_calls = [{"type": "function", "function": {"name": tool_messages[0]["name"], "arguments": {}}}] dummy_messages = [ {"role": "user", "content": "dummy"}, {"role": "assistant", "content": "", "tool_calls": dummy_tool_calls}, From 325360263893807ff12e8b05186622673a52f934 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Tue, 7 Apr 2026 14:49:54 +0000 Subject: [PATCH 16/17] style --- trl/experimental/async_grpo/async_rollout_worker.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/trl/experimental/async_grpo/async_rollout_worker.py b/trl/experimental/async_grpo/async_rollout_worker.py index 966183811e6..58379395191 100644 --- a/trl/experimental/async_grpo/async_rollout_worker.py +++ b/trl/experimental/async_grpo/async_rollout_worker.py @@ -560,7 +560,7 @@ def _get_tool_suffix_ids(self, tool_messages: list[dict[str, Any]]) -> list[int] dummy_tool_calls = [{"type": "function", "function": {"name": tool_messages[0]["name"], "arguments": {}}}] dummy_messages = [ {"role": "user", "content": "dummy"}, - { + { "role": "assistant", # "content" is required here because VLM processors crash on tokenize=True without it # (KeyError in processing_utils.py). See huggingface/transformers#45290. From 392dece4f3d9217dae969d26936a69d436c1e578 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= <45557362+qgallouedec@users.noreply.github.com> Date: Wed, 8 Apr 2026 21:45:09 -0400 Subject: [PATCH 17/17] Apply suggestions from code review MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com> --- tests/test_chat_template_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_chat_template_utils.py b/tests/test_chat_template_utils.py index 2a68e47c4da..4c4193d8b93 100644 --- a/tests/test_chat_template_utils.py +++ b/tests/test_chat_template_utils.py @@ -563,7 +563,7 @@ def test_parse_response_tool_call_without_arguments(self, tokenizer_name): tool_calls = [{"type": "function", "function": {"name": "ping", "arguments": {}}}] messages = [ {"role": "user", "content": "Ping the service."}, - {"role": "assistant", "content": "", "tool_calls": tool_calls}, + {"role": "assistant", "tool_calls": tool_calls}, ] prefix = tokenizer.apply_chat_template(messages[:1], add_generation_prompt=True).input_ids text = tokenizer.apply_chat_template(messages).input_ids