From 4b3aa51d70e6fa8dc65aa1d5ce94509e5e535e17 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Sun, 5 Apr 2026 18:44:21 +0000 Subject: [PATCH 01/15] Narrow prefix-preserving check to the actual requirement --- tests/test_chat_template_utils.py | 17 ++++----- trl/chat_template_utils.py | 57 +++++++++---------------------- 2 files changed, 23 insertions(+), 51 deletions(-) diff --git a/tests/test_chat_template_utils.py b/tests/test_chat_template_utils.py index 6459e3d8fef..2f6cecbb262 100644 --- a/tests/test_chat_template_utils.py +++ b/tests/test_chat_template_utils.py @@ -153,6 +153,8 @@ def test_prefix_preserving_template(self): {{- '<|im_start|>user\n' + message.content + '<|im_end|>\n' }} {%- elif message.role == 'assistant' %} {{- '<|im_start|>assistant\n' + message.content + '<|im_end|>\n' }} + {%- elif message.role == 'tool' %} + {{- '<|im_start|>tool\n' + message.content + '<|im_end|>\n' }} {%- endif %} {%- endfor %} @@ -164,8 +166,9 @@ def test_prefix_preserving_template(self): def test_non_prefix_preserving_template(self): tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen3MoeForSequenceClassification") - # The following template is quite typical of models like Qwen3 and GPT-OSS, where the thinking part is - # only present for last assistant message, which makes it non-prefix-preserving. + # The following template is quite typical of models like Qwen3 and GPT-OSS, where the thinking part (even + # empty) is only present for last assistant message, which makes it non-prefix-preserving: appending a tool + # message changes the earlier output. 
# docstyle-ignore tokenizer.chat_template = textwrap.dedent(r""" {%- if messages[0].role == 'system' %} @@ -203,6 +206,8 @@ def test_non_prefix_preserving_template(self): {{- '<|im_start|>' + message.role + '\n' + content }} {%- endif %} {{- '<|im_end|>\n' }} + {%- elif message.role == "tool" %} + {{- '<|im_start|>tool\n' + content + '<|im_end|>\n' }} {%- endif %} {%- endfor %} {%- if add_generation_prompt %} @@ -218,14 +223,6 @@ def test_non_prefix_preserving_template(self): "tokenizer_name", [ pytest.param("trl-internal-testing/tiny-Qwen3MoeForSequenceClassification", id="qwen3"), - pytest.param( - "trl-internal-testing/tiny-Qwen3_5ForConditionalGeneration", - id="qwen35", - marks=pytest.mark.skipif( - Version(transformers.__version__) < Version("5.0.0"), - reason="Qwen3.5 tokenizer requires transformers>=5.0.0", - ), - ), ], ) class TestGetTrainingChatTemplate: diff --git a/trl/chat_template_utils.py b/trl/chat_template_utils.py index 90be3fa3b57..eeb3f74e794 100644 --- a/trl/chat_template_utils.py +++ b/trl/chat_template_utils.py @@ -630,6 +630,10 @@ def is_chat_template_prefix_preserving(tokenizer: PreTrainedTokenizer) -> bool: """ Check whether the chat template preserves prefixes when applied. + A prefix-preserving chat template renders earlier messages identically regardless of what messages follow. + This property is required by `_get_tool_suffix_ids`, which extracts tool response formatting tokens by + comparing tokenizations with and without tool messages appended. + Args: tokenizer (`PreTrainedTokenizer`): Tokenizer instance to check. 
@@ -640,22 +644,18 @@ def is_chat_template_prefix_preserving(tokenizer: PreTrainedTokenizer) -> bool: """ messages1 = [ {"role": "user", "content": "What color is the sky?"}, - ] - messages2 = [ - {"role": "user", "content": "What color is the sky?"}, {"role": "assistant", "content": "It is blue."}, ] - messages3 = [ + messages2 = [ {"role": "user", "content": "What color is the sky?"}, {"role": "assistant", "content": "It is blue."}, - {"role": "user", "content": "And at night?"}, + {"role": "tool", "name": "say_hi", "content": "Hi!"}, ] - text1 = tokenizer.apply_chat_template(messages1, tokenize=False, add_generation_prompt=True) - text2 = tokenizer.apply_chat_template(messages2, tokenize=False) - text3 = tokenizer.apply_chat_template(messages3, tokenize=False) + text1 = tokenizer.apply_chat_template(messages1, tokenize=False) + text2 = tokenizer.apply_chat_template(messages2, tokenize=False, add_generation_prompt=True) - return text2.startswith(text1) and text3.startswith(text2) + return text2.startswith(text1) # Modifications: @@ -749,33 +749,12 @@ def is_chat_template_prefix_preserving(tokenizer: PreTrainedTokenizer) -> bool: {%- endif %}""" -# Modifications: -# - {%- if '' in content %} -# + {%- if '' in content and '' in content %} -# Always check for both tags to avoid edge cases where the model generates only one tag, which would otherwise be parsed incorrectly -# - {{- '<|im_start|>' + message.role + '\n' + content }} -# + {{- '<|im_start|>' + message.role + '\n\n' + reasoning_content + '\n\n\n' + content }} -# Always include thinking block during training. It's important to have a prefix-preserving template. 
-def _patch_qwen3_5_training_template(template: str) -> str: - return template.replace( - "{%- if '' in content %}", - "{%- if '' in content and '' in content %}", - ).replace( - "{{- '<|im_start|>' + message.role + '\\n' + content }}", - "{{- '<|im_start|>' + message.role + '\\n\\n' + reasoning_content + '\\n\\n\\n' + content }}", - ) - - -qwen3_5_training_chat_template_2b_and_below = _patch_qwen3_5_training_template(qwen3_5_chat_template_2b_and_below) -qwen3_5_training_chat_template_4b_and_above = _patch_qwen3_5_training_template(qwen3_5_chat_template_4b_and_above) - - def get_training_chat_template(tokenizer: PreTrainedTokenizer) -> str | None: r""" Get a prefix-preserving chat template for training, if needed. - If the tokenizer's template isn't prefix-preserving, returns a training-compatible template (currently Qwen3 and - Qwen3.5 supported). Otherwise, returns `None`. + If the tokenizer's template isn't prefix-preserving, returns a training-compatible template (currently Qwen3 + supported). Otherwise, returns `None`. Args: tokenizer (`PreTrainedTokenizer`): @@ -799,21 +778,21 @@ def get_training_chat_template(tokenizer: PreTrainedTokenizer) -> str | None: >>> messages2 = [ ... {"role": "user", "content": "What color is the sky?"}, ... {"role": "assistant", "content": "It is blue."}, - ... {"role": "user", "content": "And at night?"}, + ... {"role": "tool", "name": "say_hi", "content": "Hi!"}, ... 
] >>> tokenizer.apply_chat_template(messages1, tokenize=False) '<|im_start|>user\nWhat color is the sky?<|im_end|>\n<|im_start|>assistant\n\n\n\n\nIt is blue.<|im_end|>\n' - >>> tokenizer.apply_chat_template(messages2, tokenize=False) - '<|im_start|>user\nWhat color is the sky?<|im_end|>\n<|im_start|>assistant\nIt is blue.<|im_end|>\n<|im_start|>user\nAnd at night?<|im_end|>\n' + >>> tokenizer.apply_chat_template(messages2, tokenize=False, add_generation_prompt=True) + '<|im_start|>user\nWhat color is the sky?<|im_end|>\n<|im_start|>assistant\nIt is blue.<|im_end|>\n<|im_start|>user\n\nHi!\n<|im_end|>\n<|im_start|>assistant\n' >>> # ^ think tags missing >>> chat_template = get_training_chat_template(tokenizer) >>> tokenizer.apply_chat_template(messages1, tokenize=False, chat_template=chat_template) '<|im_start|>user\nWhat color is the sky?<|im_end|>\n<|im_start|>assistant\n\n\n\n\nIt is blue.<|im_end|>\n' - >>> tokenizer.apply_chat_template(messages2, tokenize=False, chat_template=chat_template) - '<|im_start|>user\nWhat color is the sky?<|im_end|>\n<|im_start|>assistant\n\n\n\n\nIt is blue.<|im_end|>\n<|im_start|>user\nAnd at night?<|im_end|>\n' + >>> tokenizer.apply_chat_template(messages2, tokenize=False, add_generation_prompt=True, chat_template=chat_template) + '<|im_start|>user\nWhat color is the sky?<|im_end|>\n<|im_start|>assistant\n\n\n\n\nIt is blue.<|im_end|>\n<|im_start|>user\n\nHi!\n<|im_end|>\n<|im_start|>assistant\n' ``` """ # First check if patching is needed @@ -822,10 +801,6 @@ def get_training_chat_template(tokenizer: PreTrainedTokenizer) -> str | None: if tokenizer.chat_template == qwen3_chat_template: return qwen3_training_chat_template - if tokenizer.chat_template == qwen3_5_chat_template_2b_and_below: - return qwen3_5_training_chat_template_2b_and_below - if tokenizer.chat_template == qwen3_5_chat_template_4b_and_above: - return qwen3_5_training_chat_template_4b_and_above else: raise ValueError( "The tokenizer's chat template is not 
prefix-preserving and patching is not supported for this template. " From 730070bcf14f58ae60c9646ec379441d4d112c01 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Sun, 5 Apr 2026 19:00:02 +0000 Subject: [PATCH 02/15] Update chat template examples to use multiplication function calls --- trl/chat_template_utils.py | 30 ++++++++++++++---------------- 1 file changed, 14 insertions(+), 16 deletions(-) diff --git a/trl/chat_template_utils.py b/trl/chat_template_utils.py index eeb3f74e794..21bbe598685 100644 --- a/trl/chat_template_utils.py +++ b/trl/chat_template_utils.py @@ -643,13 +643,13 @@ def is_chat_template_prefix_preserving(tokenizer: PreTrainedTokenizer) -> bool: `True` if the chat template preserves prefixes, `False` otherwise. """ messages1 = [ - {"role": "user", "content": "What color is the sky?"}, - {"role": "assistant", "content": "It is blue."}, + {"role": "user", "content": "What is 2 * 3?"}, + {"role": "assistant", "content": "", "tool_calls": [{"type": "function", "function": {"name": "multiply", "arguments": {"a": 2, "b": 3}}}]}, ] messages2 = [ - {"role": "user", "content": "What color is the sky?"}, - {"role": "assistant", "content": "It is blue."}, - {"role": "tool", "name": "say_hi", "content": "Hi!"}, + {"role": "user", "content": "What is 2 * 3?"}, + {"role": "assistant", "content": "", "tool_calls": [{"type": "function", "function": {"name": "multiply", "arguments": {"a": 2, "b": 3}}}]}, + {"role": "tool", "name": "multiply", "content": "6"}, ] text1 = tokenizer.apply_chat_template(messages1, tokenize=False) @@ -772,27 +772,25 @@ def get_training_chat_template(tokenizer: PreTrainedTokenizer) -> str | None: >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B") >>> messages1 = [ - ... {"role": "user", "content": "What color is the sky?"}, - ... {"role": "assistant", "content": "It is blue."}, + ... {"role": "user", "content": "What is 2 * 3?"}, + ... 
{"role": "assistant", "content": "", "tool_calls": [{"type": "function", "function": {"name": "multiply", "arguments": {"a": 2, "b": 3}}}]}, ... ] - >>> messages2 = [ - ... {"role": "user", "content": "What color is the sky?"}, - ... {"role": "assistant", "content": "It is blue."}, - ... {"role": "tool", "name": "say_hi", "content": "Hi!"}, + >>> messages2 = messages1 + [ + ... {"role": "tool", "name": "multiply", "content": "6"}, ... ] >>> tokenizer.apply_chat_template(messages1, tokenize=False) - '<|im_start|>user\nWhat color is the sky?<|im_end|>\n<|im_start|>assistant\n\n\n\n\nIt is blue.<|im_end|>\n' + '<|im_start|>user\nWhat is 2 * 3?<|im_end|>\n<|im_start|>assistant\n\n\n\n\n\n{"name": "multiply", "arguments": {"a": 2, "b": 3}}\n<|im_end|>\n' >>> tokenizer.apply_chat_template(messages2, tokenize=False, add_generation_prompt=True) - '<|im_start|>user\nWhat color is the sky?<|im_end|>\n<|im_start|>assistant\nIt is blue.<|im_end|>\n<|im_start|>user\n\nHi!\n<|im_end|>\n<|im_start|>assistant\n' + '<|im_start|>user\nWhat is 2 * 3?<|im_end|>\n<|im_start|>assistant\n\n{"name": "multiply", "arguments": {"a": 2, "b": 3}}\n<|im_end|>\n<|im_start|>user\n\n6\n<|im_end|>\n<|im_start|>assistant\n' - >>> # ^ think tags missing + >>> # ^ think tags missing >>> chat_template = get_training_chat_template(tokenizer) >>> tokenizer.apply_chat_template(messages1, tokenize=False, chat_template=chat_template) - '<|im_start|>user\nWhat color is the sky?<|im_end|>\n<|im_start|>assistant\n\n\n\n\nIt is blue.<|im_end|>\n' + '<|im_start|>user\nWhat is 2 * 3?<|im_end|>\n<|im_start|>assistant\n\n\n\n\n\n{"name": "multiply", "arguments": {"a": 2, "b": 3}}\n<|im_end|>\n' >>> tokenizer.apply_chat_template(messages2, tokenize=False, add_generation_prompt=True, chat_template=chat_template) - '<|im_start|>user\nWhat color is the sky?<|im_end|>\n<|im_start|>assistant\n\n\n\n\nIt is blue.<|im_end|>\n<|im_start|>user\n\nHi!\n<|im_end|>\n<|im_start|>assistant\n' + '<|im_start|>user\nWhat is 2 * 
3?<|im_end|>\n<|im_start|>assistant\n\n\n\n\n\n{"name": "multiply", "arguments": {"a": 2, "b": 3}}\n<|im_end|>\n<|im_start|>user\n\n6\n<|im_end|>\n<|im_start|>assistant\n' ``` """ # First check if patching is needed From 4622d77234bf46d010b782fcebab1b2ed2d2e32a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Sun, 5 Apr 2026 19:01:42 +0000 Subject: [PATCH 03/15] style --- trl/chat_template_utils.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/trl/chat_template_utils.py b/trl/chat_template_utils.py index 21bbe598685..6b9d3c24e2b 100644 --- a/trl/chat_template_utils.py +++ b/trl/chat_template_utils.py @@ -644,11 +644,19 @@ def is_chat_template_prefix_preserving(tokenizer: PreTrainedTokenizer) -> bool: """ messages1 = [ {"role": "user", "content": "What is 2 * 3?"}, - {"role": "assistant", "content": "", "tool_calls": [{"type": "function", "function": {"name": "multiply", "arguments": {"a": 2, "b": 3}}}]}, + { + "role": "assistant", + "content": "", + "tool_calls": [{"type": "function", "function": {"name": "multiply", "arguments": {"a": 2, "b": 3}}}], + }, ] messages2 = [ {"role": "user", "content": "What is 2 * 3?"}, - {"role": "assistant", "content": "", "tool_calls": [{"type": "function", "function": {"name": "multiply", "arguments": {"a": 2, "b": 3}}}]}, + { + "role": "assistant", + "content": "", + "tool_calls": [{"type": "function", "function": {"name": "multiply", "arguments": {"a": 2, "b": 3}}}], + }, {"role": "tool", "name": "multiply", "content": "6"}, ] From 08d4c518ec29b2c9d3e0f2fff5eae90ca9aebdf9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Sun, 5 Apr 2026 19:33:46 +0000 Subject: [PATCH 04/15] Move chat templates from inline strings to `.jinja` files --- trl/chat_template_utils.py | 503 +----------------- trl/chat_templates/README.md | 50 ++ trl/chat_templates/qwen3.jinja | 89 ++++ trl/chat_templates/qwen3_5_2b_and_below.jinja | 154 ++++++ 
trl/chat_templates/qwen3_5_4b_and_above.jinja | 154 ++++++ trl/chat_templates/qwen3_training.jinja | 88 +++ 6 files changed, 546 insertions(+), 492 deletions(-) create mode 100644 trl/chat_templates/README.md create mode 100644 trl/chat_templates/qwen3.jinja create mode 100644 trl/chat_templates/qwen3_5_2b_and_below.jinja create mode 100644 trl/chat_templates/qwen3_5_4b_and_above.jinja create mode 100644 trl/chat_templates/qwen3_training.jinja diff --git a/trl/chat_template_utils.py b/trl/chat_template_utils.py index 6b9d3c24e2b..3e86874816a 100644 --- a/trl/chat_template_utils.py +++ b/trl/chat_template_utils.py @@ -12,8 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. +from pathlib import Path + from transformers import AddedToken, AutoTokenizer, PreTrainedModel, PreTrainedTokenizer +_CHAT_TEMPLATES_DIR = Path(__file__).parent / "chat_templates" + def clone_chat_template( model: PreTrainedModel, @@ -178,409 +182,12 @@ def clone_chat_template( }, } -# docstyle-ignore -qwen3_chat_template = r"""{%- if tools %} - {{- '<|im_start|>system\n' }} - {%- if messages[0].role == 'system' %} - {{- messages[0].content + '\n\n' }} - {%- endif %} - {{- "# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within XML tags:\n" }} - {%- for tool in tools %} - {{- "\n" }} - {{- tool | tojson }} - {%- endfor %} - {{- "\n\n\nFor each function call, return a json object with function name and arguments within XML tags:\n\n{\"name\": , \"arguments\": }\n<|im_end|>\n" }} -{%- else %} - {%- if messages[0].role == 'system' %} - {{- '<|im_start|>system\n' + messages[0].content + '<|im_end|>\n' }} - {%- endif %} -{%- endif %} -{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %} -{%- for message in messages[::-1] %} - {%- set index = (messages|length - 1) - loop.index0 %} - {%- if ns.multi_step_tool and message.role == "user" 
and message.content is string and not(message.content.startswith('') and message.content.endswith('')) %} - {%- set ns.multi_step_tool = false %} - {%- set ns.last_query_index = index %} - {%- endif %} -{%- endfor %} -{%- for message in messages %} - {%- if message.content is string %} - {%- set content = message.content %} - {%- else %} - {%- set content = '' %} - {%- endif %} - {%- if (message.role == "user") or (message.role == "system" and not loop.first) %} - {{- '<|im_start|>' + message.role + '\n' + content + '<|im_end|>' + '\n' }} - {%- elif message.role == "assistant" %} - {%- set reasoning_content = '' %} - {%- if message.reasoning_content is string %} - {%- set reasoning_content = message.reasoning_content %} - {%- else %} - {%- if '' in content %} - {%- set reasoning_content = content.split('')[0].rstrip('\n').split('')[-1].lstrip('\n') %} - {%- set content = content.split('')[-1].lstrip('\n') %} - {%- endif %} - {%- endif %} - {%- if loop.index0 > ns.last_query_index %} - {%- if loop.last or (not loop.last and reasoning_content) %} - {{- '<|im_start|>' + message.role + '\n\n' + reasoning_content.strip('\n') + '\n\n\n' + content.lstrip('\n') }} - {%- else %} - {{- '<|im_start|>' + message.role + '\n' + content }} - {%- endif %} - {%- else %} - {{- '<|im_start|>' + message.role + '\n' + content }} - {%- endif %} - {%- if message.tool_calls %} - {%- for tool_call in message.tool_calls %} - {%- if (loop.first and content) or (not loop.first) %} - {{- '\n' }} - {%- endif %} - {%- if tool_call.function %} - {%- set tool_call = tool_call.function %} - {%- endif %} - {{- '\n{"name": "' }} - {{- tool_call.name }} - {{- '", "arguments": ' }} - {%- if tool_call.arguments is string %} - {{- tool_call.arguments }} - {%- else %} - {{- tool_call.arguments | tojson }} - {%- endif %} - {{- '}\n' }} - {%- endfor %} - {%- endif %} - {{- '<|im_end|>\n' }} - {%- elif message.role == "tool" %} - {%- if loop.first or (messages[loop.index0 - 1].role != "tool") %} - {{- 
'<|im_start|>user' }} - {%- endif %} - {{- '\n\n' }} - {{- content }} - {{- '\n' }} - {%- if loop.last or (messages[loop.index0 + 1].role != "tool") %} - {{- '<|im_end|>\n' }} - {%- endif %} - {%- endif %} -{%- endfor %} -{%- if add_generation_prompt %} - {{- '<|im_start|>assistant\n' }} - {%- if enable_thinking is defined and enable_thinking is false %} - {{- '\n\n\n\n' }} - {%- endif %} -{%- endif %}""" - -# docstyle-ignore -qwen3_5_chat_template_2b_and_below = r"""{%- set image_count = namespace(value=0) %} -{%- set video_count = namespace(value=0) %} -{%- macro render_content(content, do_vision_count, is_system_content=false) %} - {%- if content is string %} - {{- content }} - {%- elif content is iterable and content is not mapping %} - {%- for item in content %} - {%- if 'image' in item or 'image_url' in item or item.type == 'image' %} - {%- if is_system_content %} - {{- raise_exception('System message cannot contain images.') }} - {%- endif %} - {%- if do_vision_count %} - {%- set image_count.value = image_count.value + 1 %} - {%- endif %} - {%- if add_vision_id %} - {{- 'Picture ' ~ image_count.value ~ ': ' }} - {%- endif %} - {{- '<|vision_start|><|image_pad|><|vision_end|>' }} - {%- elif 'video' in item or item.type == 'video' %} - {%- if is_system_content %} - {{- raise_exception('System message cannot contain videos.') }} - {%- endif %} - {%- if do_vision_count %} - {%- set video_count.value = video_count.value + 1 %} - {%- endif %} - {%- if add_vision_id %} - {{- 'Video ' ~ video_count.value ~ ': ' }} - {%- endif %} - {{- '<|vision_start|><|video_pad|><|vision_end|>' }} - {%- elif 'text' in item %} - {{- item.text }} - {%- else %} - {{- raise_exception('Unexpected item type in content.') }} - {%- endif %} - {%- endfor %} - {%- elif content is none or content is undefined %} - {{- '' }} - {%- else %} - {{- raise_exception('Unexpected content type.') }} - {%- endif %} -{%- endmacro %} -{%- if not messages %} - {{- raise_exception('No messages provided.') 
}} -{%- endif %} -{%- if tools and tools is iterable and tools is not mapping %} - {{- '<|im_start|>system\n' }} - {{- "# Tools\n\nYou have access to the following functions:\n\n" }} - {%- for tool in tools %} - {{- "\n" }} - {{- tool | tojson }} - {%- endfor %} - {{- "\n" }} - {{- '\n\nIf you choose to call a function ONLY reply in the following format with NO suffix:\n\n\n\n\nvalue_1\n\n\nThis is the value for the second parameter\nthat can span\nmultiple lines\n\n\n\n\n\nReminder:\n- Function calls MUST follow the specified format: an inner block must be nested within XML tags\n- Required parameters MUST be specified\n- You may provide optional reasoning for your function call in natural language BEFORE the function call, but NOT after\n- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls\n' }} - {%- if messages[0].role == 'system' %} - {%- set content = render_content(messages[0].content, false, true)|trim %} - {%- if content %} - {{- '\n\n' + content }} - {%- endif %} - {%- endif %} - {{- '<|im_end|>\n' }} -{%- else %} - {%- if messages[0].role == 'system' %} - {%- set content = render_content(messages[0].content, false, true)|trim %} - {{- '<|im_start|>system\n' + content + '<|im_end|>\n' }} - {%- endif %} -{%- endif %} -{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %} -{%- for message in messages[::-1] %} - {%- set index = (messages|length - 1) - loop.index0 %} - {%- if ns.multi_step_tool and message.role == "user" %} - {%- set content = render_content(message.content, false)|trim %} - {%- if not(content.startswith('') and content.endswith('')) %} - {%- set ns.multi_step_tool = false %} - {%- set ns.last_query_index = index %} - {%- endif %} - {%- endif %} -{%- endfor %} -{%- if ns.multi_step_tool %} - {{- raise_exception('No user query found in messages.') }} -{%- endif %} -{%- for message in messages %} - {%- set content = 
render_content(message.content, true)|trim %} - {%- if message.role == "system" %} - {%- if not loop.first %} - {{- raise_exception('System message must be at the beginning.') }} - {%- endif %} - {%- elif message.role == "user" %} - {{- '<|im_start|>' + message.role + '\n' + content + '<|im_end|>' + '\n' }} - {%- elif message.role == "assistant" %} - {%- set reasoning_content = '' %} - {%- if message.reasoning_content is string %} - {%- set reasoning_content = message.reasoning_content %} - {%- else %} - {%- if '' in content %} - {%- set reasoning_content = content.split('')[0].rstrip('\n').split('')[-1].lstrip('\n') %} - {%- set content = content.split('')[-1].lstrip('\n') %} - {%- endif %} - {%- endif %} - {%- set reasoning_content = reasoning_content|trim %} - {%- if loop.index0 > ns.last_query_index %} - {{- '<|im_start|>' + message.role + '\n\n' + reasoning_content + '\n\n\n' + content }} - {%- else %} - {{- '<|im_start|>' + message.role + '\n' + content }} - {%- endif %} - {%- if message.tool_calls and message.tool_calls is iterable and message.tool_calls is not mapping %} - {%- for tool_call in message.tool_calls %} - {%- if tool_call.function is defined %} - {%- set tool_call = tool_call.function %} - {%- endif %} - {%- if loop.first %} - {%- if content|trim %} - {{- '\n\n\n\n' }} - {%- else %} - {{- '\n\n' }} - {%- endif %} - {%- else %} - {{- '\n\n\n' }} - {%- endif %} - {%- if tool_call.arguments is defined %} - {%- for args_name, args_value in tool_call.arguments|items %} - {{- '\n' }} - {%- set args_value = args_value | tojson | safe if args_value is mapping or (args_value is sequence and args_value is not string) else args_value | string %} - {{- args_value }} - {{- '\n\n' }} - {%- endfor %} - {%- endif %} - {{- '\n' }} - {%- endfor %} - {%- endif %} - {{- '<|im_end|>\n' }} - {%- elif message.role == "tool" %} - {%- if loop.previtem and loop.previtem.role != "tool" %} - {{- '<|im_start|>user' }} - {%- endif %} - {{- '\n\n' }} - {{- content }} - {{- 
'\n' }} - {%- if not loop.last and loop.nextitem.role != "tool" %} - {{- '<|im_end|>\n' }} - {%- elif loop.last %} - {{- '<|im_end|>\n' }} - {%- endif %} - {%- else %} - {{- raise_exception('Unexpected message role.') }} - {%- endif %} -{%- endfor %} -{%- if add_generation_prompt %} - {{- '<|im_start|>assistant\n' }} - {%- if enable_thinking is defined and enable_thinking is true %} - {{- '\n' }} - {%- else %} - {{- '\n\n\n\n' }} - {%- endif %} -{%- endif %}""" - - -# docstyle-ignore -qwen3_5_chat_template_4b_and_above = r"""{%- set image_count = namespace(value=0) %} -{%- set video_count = namespace(value=0) %} -{%- macro render_content(content, do_vision_count, is_system_content=false) %} - {%- if content is string %} - {{- content }} - {%- elif content is iterable and content is not mapping %} - {%- for item in content %} - {%- if 'image' in item or 'image_url' in item or item.type == 'image' %} - {%- if is_system_content %} - {{- raise_exception('System message cannot contain images.') }} - {%- endif %} - {%- if do_vision_count %} - {%- set image_count.value = image_count.value + 1 %} - {%- endif %} - {%- if add_vision_id %} - {{- 'Picture ' ~ image_count.value ~ ': ' }} - {%- endif %} - {{- '<|vision_start|><|image_pad|><|vision_end|>' }} - {%- elif 'video' in item or item.type == 'video' %} - {%- if is_system_content %} - {{- raise_exception('System message cannot contain videos.') }} - {%- endif %} - {%- if do_vision_count %} - {%- set video_count.value = video_count.value + 1 %} - {%- endif %} - {%- if add_vision_id %} - {{- 'Video ' ~ video_count.value ~ ': ' }} - {%- endif %} - {{- '<|vision_start|><|video_pad|><|vision_end|>' }} - {%- elif 'text' in item %} - {{- item.text }} - {%- else %} - {{- raise_exception('Unexpected item type in content.') }} - {%- endif %} - {%- endfor %} - {%- elif content is none or content is undefined %} - {{- '' }} - {%- else %} - {{- raise_exception('Unexpected content type.') }} - {%- endif %} -{%- endmacro %} -{%- if not 
messages %} - {{- raise_exception('No messages provided.') }} -{%- endif %} -{%- if tools and tools is iterable and tools is not mapping %} - {{- '<|im_start|>system\n' }} - {{- "# Tools\n\nYou have access to the following functions:\n\n" }} - {%- for tool in tools %} - {{- "\n" }} - {{- tool | tojson }} - {%- endfor %} - {{- "\n" }} - {{- '\n\nIf you choose to call a function ONLY reply in the following format with NO suffix:\n\n\n\n\nvalue_1\n\n\nThis is the value for the second parameter\nthat can span\nmultiple lines\n\n\n\n\n\nReminder:\n- Function calls MUST follow the specified format: an inner block must be nested within XML tags\n- Required parameters MUST be specified\n- You may provide optional reasoning for your function call in natural language BEFORE the function call, but NOT after\n- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls\n' }} - {%- if messages[0].role == 'system' %} - {%- set content = render_content(messages[0].content, false, true)|trim %} - {%- if content %} - {{- '\n\n' + content }} - {%- endif %} - {%- endif %} - {{- '<|im_end|>\n' }} -{%- else %} - {%- if messages[0].role == 'system' %} - {%- set content = render_content(messages[0].content, false, true)|trim %} - {{- '<|im_start|>system\n' + content + '<|im_end|>\n' }} - {%- endif %} -{%- endif %} -{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %} -{%- for message in messages[::-1] %} - {%- set index = (messages|length - 1) - loop.index0 %} - {%- if ns.multi_step_tool and message.role == "user" %} - {%- set content = render_content(message.content, false)|trim %} - {%- if not(content.startswith('') and content.endswith('')) %} - {%- set ns.multi_step_tool = false %} - {%- set ns.last_query_index = index %} - {%- endif %} - {%- endif %} -{%- endfor %} -{%- if ns.multi_step_tool %} - {{- raise_exception('No user query found in messages.') }} -{%- endif 
%} -{%- for message in messages %} - {%- set content = render_content(message.content, true)|trim %} - {%- if message.role == "system" %} - {%- if not loop.first %} - {{- raise_exception('System message must be at the beginning.') }} - {%- endif %} - {%- elif message.role == "user" %} - {{- '<|im_start|>' + message.role + '\n' + content + '<|im_end|>' + '\n' }} - {%- elif message.role == "assistant" %} - {%- set reasoning_content = '' %} - {%- if message.reasoning_content is string %} - {%- set reasoning_content = message.reasoning_content %} - {%- else %} - {%- if '' in content %} - {%- set reasoning_content = content.split('')[0].rstrip('\n').split('')[-1].lstrip('\n') %} - {%- set content = content.split('')[-1].lstrip('\n') %} - {%- endif %} - {%- endif %} - {%- set reasoning_content = reasoning_content|trim %} - {%- if loop.index0 > ns.last_query_index %} - {{- '<|im_start|>' + message.role + '\n\n' + reasoning_content + '\n\n\n' + content }} - {%- else %} - {{- '<|im_start|>' + message.role + '\n' + content }} - {%- endif %} - {%- if message.tool_calls and message.tool_calls is iterable and message.tool_calls is not mapping %} - {%- for tool_call in message.tool_calls %} - {%- if tool_call.function is defined %} - {%- set tool_call = tool_call.function %} - {%- endif %} - {%- if loop.first %} - {%- if content|trim %} - {{- '\n\n\n\n' }} - {%- else %} - {{- '\n\n' }} - {%- endif %} - {%- else %} - {{- '\n\n\n' }} - {%- endif %} - {%- if tool_call.arguments is defined %} - {%- for args_name, args_value in tool_call.arguments|items %} - {{- '\n' }} - {%- set args_value = args_value | tojson | safe if args_value is mapping or (args_value is sequence and args_value is not string) else args_value | string %} - {{- args_value }} - {{- '\n\n' }} - {%- endfor %} - {%- endif %} - {{- '\n' }} - {%- endfor %} - {%- endif %} - {{- '<|im_end|>\n' }} - {%- elif message.role == "tool" %} - {%- if loop.previtem and loop.previtem.role != "tool" %} - {{- '<|im_start|>user' }} - 
{%- endif %} - {{- '\n\n' }} - {{- content }} - {{- '\n' }} - {%- if not loop.last and loop.nextitem.role != "tool" %} - {{- '<|im_end|>\n' }} - {%- elif loop.last %} - {{- '<|im_end|>\n' }} - {%- endif %} - {%- else %} - {{- raise_exception('Unexpected message role.') }} - {%- endif %} -{%- endfor %} -{%- if add_generation_prompt %} - {{- '<|im_start|>assistant\n' }} - {%- if enable_thinking is defined and enable_thinking is false %} - {{- '\n\n\n\n' }} - {%- else %} - {{- '\n' }} - {%- endif %} -{%- endif %}""" +qwen3_chat_template = (_CHAT_TEMPLATES_DIR / "qwen3.jinja").read_text() + +qwen3_5_chat_template_2b_and_below = (_CHAT_TEMPLATES_DIR / "qwen3_5_2b_and_below.jinja").read_text() + + +qwen3_5_chat_template_4b_and_above = (_CHAT_TEMPLATES_DIR / "qwen3_5_4b_and_above.jinja").read_text() def add_response_schema(tokenizer: PreTrainedTokenizer) -> PreTrainedTokenizer: @@ -666,95 +273,7 @@ def is_chat_template_prefix_preserving(tokenizer: PreTrainedTokenizer) -> bool: return text2.startswith(text1) -# Modifications: -# - {%- if '' in content %} -# + {%- if '' in content and '' in content %} -# Always check for both tags to avoid edge cases where the model generates only one tag, which would otherwise be parsed incorrectly -# - {%- if loop.index0 > ns.last_query_index %} ... {%- endif %} -# + {{- '<|im_start|>' + message.role + '\n\n' + reasoning_content.strip('\n') + '\n\n\n' + content.lstrip('\n') }} -# Always include thinking block during training. It's important to have a prefix-preserving template. 
-# docstyle-ignore -qwen3_training_chat_template = r"""{%- if tools %} - {{- '<|im_start|>system\n' }} - {%- if messages[0].role == 'system' %} - {{- messages[0].content + '\n\n' }} - {%- endif %} - {{- "# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within XML tags:\n" }} - {%- for tool in tools %} - {{- "\n" }} - {{- tool | tojson }} - {%- endfor %} - {{- "\n\n\nFor each function call, return a json object with function name and arguments within XML tags:\n\n{\"name\": , \"arguments\": }\n<|im_end|>\n" }} -{%- else %} - {%- if messages[0].role == 'system' %} - {{- '<|im_start|>system\n' + messages[0].content + '<|im_end|>\n' }} - {%- endif %} -{%- endif %} -{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %} -{%- for message in messages[::-1] %} - {%- set index = (messages|length - 1) - loop.index0 %} - {%- if ns.multi_step_tool and message.role == "user" and message.content is string and not(message.content.startswith('') and message.content.endswith('')) %} - {%- set ns.multi_step_tool = false %} - {%- set ns.last_query_index = index %} - {%- endif %} -{%- endfor %} -{%- for message in messages %} - {%- if message.content is string %} - {%- set content = message.content %} - {%- else %} - {%- set content = '' %} - {%- endif %} - {%- if (message.role == "user") or (message.role == "system" and not loop.first) %} - {{- '<|im_start|>' + message.role + '\n' + content + '<|im_end|>' + '\n' }} - {%- elif message.role == "assistant" %} - {%- set reasoning_content = '' %} - {%- if message.reasoning_content is string %} - {%- set reasoning_content = message.reasoning_content %} - {%- else %} - {%- if '' in content and '' in content %} - {%- set reasoning_content = content.split('')[0].rstrip('\n').split('')[-1].lstrip('\n') %} - {%- set content = content.split('')[-1].lstrip('\n') %} - {%- endif %} - {%- endif %} - {{- '<|im_start|>' + message.role + '\n\n' + 
reasoning_content.strip('\n') + '\n\n\n' + content.lstrip('\n') }} - {%- if message.tool_calls %} - {%- for tool_call in message.tool_calls %} - {%- if (loop.first and content) or (not loop.first) %} - {{- '\n' }} - {%- endif %} - {%- if tool_call.function %} - {%- set tool_call = tool_call.function %} - {%- endif %} - {{- '\n{"name": "' }} - {{- tool_call.name }} - {{- '", "arguments": ' }} - {%- if tool_call.arguments is string %} - {{- tool_call.arguments }} - {%- else %} - {{- tool_call.arguments | tojson }} - {%- endif %} - {{- '}\n' }} - {%- endfor %} - {%- endif %} - {{- '<|im_end|>\n' }} - {%- elif message.role == "tool" %} - {%- if loop.first or (messages[loop.index0 - 1].role != "tool") %} - {{- '<|im_start|>user' }} - {%- endif %} - {{- '\n\n' }} - {{- content }} - {{- '\n' }} - {%- if loop.last or (messages[loop.index0 + 1].role != "tool") %} - {{- '<|im_end|>\n' }} - {%- endif %} - {%- endif %} -{%- endfor %} -{%- if add_generation_prompt %} - {{- '<|im_start|>assistant\n' }} - {%- if enable_thinking is defined and enable_thinking is false %} - {{- '\n\n\n\n' }} - {%- endif %} -{%- endif %}""" +qwen3_training_chat_template = (_CHAT_TEMPLATES_DIR / "qwen3_training.jinja").read_text() def get_training_chat_template(tokenizer: PreTrainedTokenizer) -> str | None: diff --git a/trl/chat_templates/README.md b/trl/chat_templates/README.md new file mode 100644 index 00000000000..5eb01cc4aca --- /dev/null +++ b/trl/chat_templates/README.md @@ -0,0 +1,50 @@ +# Chat Templates + +Jinja2 chat templates stored here serve two purposes: + +1. **Identity comparison**: detecting which model is being used (by comparing `tokenizer.chat_template` against known templates) to add the appropriate response schema (`add_response_schema`) or swap in a training template (`get_training_chat_template`). +2. **Training patches**: modified templates that fix prefix-preservation issues for the GRPO tool call loop. 
+ +**Why prefix-preserving?** The GRPO tool call loop extracts tool response formatting tokens by comparing tokenizations with and without tool messages appended (`_get_tool_suffix_ids`). This requires the chat template to be *prefix-preserving*: appending messages must not change how earlier messages are rendered. + +## Original templates + +Used for identity comparison only. + +### `qwen3.jinja` + +Original Qwen3 chat template. + +### `qwen3_5_2b_and_below.jinja` / `qwen3_5_4b_and_above.jinja` + +Original Qwen3.5 chat templates. + +## Training templates + +Patched templates that fix prefix-preservation issues. Swapped in at init when tools are enabled. + +### `qwen3_training.jinja` + +Patched Qwen3 template that always includes thinking blocks, making it prefix-preserving. Diff vs `qwen3.jinja`: + +Require both `` and `` to be present before parsing, to avoid incorrect splitting when the model generates only one tag: + +```diff +- {%- if '' in content %} ++ {%- if '' in content and '' in content %} +``` + +Always include the thinking block regardless of message position. 
The original conditionally omits it based on `loop.last`, which changes the assistant rendering when a tool message is appended — breaking prefix-preservation: + +```diff +- {%- if loop.index0 > ns.last_query_index %} +- {%- if loop.last or (not loop.last and reasoning_content) %} +- {{- '<|im_start|>' + message.role + '\n\n' + reasoning_content.strip('\n') + '\n\n\n' + content.lstrip('\n') }} +- {%- else %} +- {{- '<|im_start|>' + message.role + '\n' + content }} +- {%- endif %} +- {%- else %} +- {{- '<|im_start|>' + message.role + '\n' + content }} +- {%- endif %} ++ {{- '<|im_start|>' + message.role + '\n\n' + reasoning_content.strip('\n') + '\n\n\n' + content.lstrip('\n') }} +``` diff --git a/trl/chat_templates/qwen3.jinja b/trl/chat_templates/qwen3.jinja new file mode 100644 index 00000000000..01be9b307da --- /dev/null +++ b/trl/chat_templates/qwen3.jinja @@ -0,0 +1,89 @@ +{%- if tools %} + {{- '<|im_start|>system\n' }} + {%- if messages[0].role == 'system' %} + {{- messages[0].content + '\n\n' }} + {%- endif %} + {{- "# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within XML tags:\n" }} + {%- for tool in tools %} + {{- "\n" }} + {{- tool | tojson }} + {%- endfor %} + {{- "\n\n\nFor each function call, return a json object with function name and arguments within XML tags:\n\n{\"name\": , \"arguments\": }\n<|im_end|>\n" }} +{%- else %} + {%- if messages[0].role == 'system' %} + {{- '<|im_start|>system\n' + messages[0].content + '<|im_end|>\n' }} + {%- endif %} +{%- endif %} +{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %} +{%- for message in messages[::-1] %} + {%- set index = (messages|length - 1) - loop.index0 %} + {%- if ns.multi_step_tool and message.role == "user" and message.content is string and not(message.content.startswith('') and message.content.endswith('')) %} + {%- set ns.multi_step_tool = false %} + {%- set ns.last_query_index = index 
%} + {%- endif %} +{%- endfor %} +{%- for message in messages %} + {%- if message.content is string %} + {%- set content = message.content %} + {%- else %} + {%- set content = '' %} + {%- endif %} + {%- if (message.role == "user") or (message.role == "system" and not loop.first) %} + {{- '<|im_start|>' + message.role + '\n' + content + '<|im_end|>' + '\n' }} + {%- elif message.role == "assistant" %} + {%- set reasoning_content = '' %} + {%- if message.reasoning_content is string %} + {%- set reasoning_content = message.reasoning_content %} + {%- else %} + {%- if '' in content %} + {%- set reasoning_content = content.split('')[0].rstrip('\n').split('')[-1].lstrip('\n') %} + {%- set content = content.split('')[-1].lstrip('\n') %} + {%- endif %} + {%- endif %} + {%- if loop.index0 > ns.last_query_index %} + {%- if loop.last or (not loop.last and reasoning_content) %} + {{- '<|im_start|>' + message.role + '\n\n' + reasoning_content.strip('\n') + '\n\n\n' + content.lstrip('\n') }} + {%- else %} + {{- '<|im_start|>' + message.role + '\n' + content }} + {%- endif %} + {%- else %} + {{- '<|im_start|>' + message.role + '\n' + content }} + {%- endif %} + {%- if message.tool_calls %} + {%- for tool_call in message.tool_calls %} + {%- if (loop.first and content) or (not loop.first) %} + {{- '\n' }} + {%- endif %} + {%- if tool_call.function %} + {%- set tool_call = tool_call.function %} + {%- endif %} + {{- '\n{"name": "' }} + {{- tool_call.name }} + {{- '", "arguments": ' }} + {%- if tool_call.arguments is string %} + {{- tool_call.arguments }} + {%- else %} + {{- tool_call.arguments | tojson }} + {%- endif %} + {{- '}\n' }} + {%- endfor %} + {%- endif %} + {{- '<|im_end|>\n' }} + {%- elif message.role == "tool" %} + {%- if loop.first or (messages[loop.index0 - 1].role != "tool") %} + {{- '<|im_start|>user' }} + {%- endif %} + {{- '\n\n' }} + {{- content }} + {{- '\n' }} + {%- if loop.last or (messages[loop.index0 + 1].role != "tool") %} + {{- '<|im_end|>\n' }} + {%- endif %} 
+ {%- endif %} +{%- endfor %} +{%- if add_generation_prompt %} + {{- '<|im_start|>assistant\n' }} + {%- if enable_thinking is defined and enable_thinking is false %} + {{- '\n\n\n\n' }} + {%- endif %} +{%- endif %} \ No newline at end of file diff --git a/trl/chat_templates/qwen3_5_2b_and_below.jinja b/trl/chat_templates/qwen3_5_2b_and_below.jinja new file mode 100644 index 00000000000..0ef09f214ea --- /dev/null +++ b/trl/chat_templates/qwen3_5_2b_and_below.jinja @@ -0,0 +1,154 @@ +{%- set image_count = namespace(value=0) %} +{%- set video_count = namespace(value=0) %} +{%- macro render_content(content, do_vision_count, is_system_content=false) %} + {%- if content is string %} + {{- content }} + {%- elif content is iterable and content is not mapping %} + {%- for item in content %} + {%- if 'image' in item or 'image_url' in item or item.type == 'image' %} + {%- if is_system_content %} + {{- raise_exception('System message cannot contain images.') }} + {%- endif %} + {%- if do_vision_count %} + {%- set image_count.value = image_count.value + 1 %} + {%- endif %} + {%- if add_vision_id %} + {{- 'Picture ' ~ image_count.value ~ ': ' }} + {%- endif %} + {{- '<|vision_start|><|image_pad|><|vision_end|>' }} + {%- elif 'video' in item or item.type == 'video' %} + {%- if is_system_content %} + {{- raise_exception('System message cannot contain videos.') }} + {%- endif %} + {%- if do_vision_count %} + {%- set video_count.value = video_count.value + 1 %} + {%- endif %} + {%- if add_vision_id %} + {{- 'Video ' ~ video_count.value ~ ': ' }} + {%- endif %} + {{- '<|vision_start|><|video_pad|><|vision_end|>' }} + {%- elif 'text' in item %} + {{- item.text }} + {%- else %} + {{- raise_exception('Unexpected item type in content.') }} + {%- endif %} + {%- endfor %} + {%- elif content is none or content is undefined %} + {{- '' }} + {%- else %} + {{- raise_exception('Unexpected content type.') }} + {%- endif %} +{%- endmacro %} +{%- if not messages %} + {{- raise_exception('No 
messages provided.') }} +{%- endif %} +{%- if tools and tools is iterable and tools is not mapping %} + {{- '<|im_start|>system\n' }} + {{- "# Tools\n\nYou have access to the following functions:\n\n" }} + {%- for tool in tools %} + {{- "\n" }} + {{- tool | tojson }} + {%- endfor %} + {{- "\n" }} + {{- '\n\nIf you choose to call a function ONLY reply in the following format with NO suffix:\n\n\n\n\nvalue_1\n\n\nThis is the value for the second parameter\nthat can span\nmultiple lines\n\n\n\n\n\nReminder:\n- Function calls MUST follow the specified format: an inner block must be nested within XML tags\n- Required parameters MUST be specified\n- You may provide optional reasoning for your function call in natural language BEFORE the function call, but NOT after\n- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls\n' }} + {%- if messages[0].role == 'system' %} + {%- set content = render_content(messages[0].content, false, true)|trim %} + {%- if content %} + {{- '\n\n' + content }} + {%- endif %} + {%- endif %} + {{- '<|im_end|>\n' }} +{%- else %} + {%- if messages[0].role == 'system' %} + {%- set content = render_content(messages[0].content, false, true)|trim %} + {{- '<|im_start|>system\n' + content + '<|im_end|>\n' }} + {%- endif %} +{%- endif %} +{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %} +{%- for message in messages[::-1] %} + {%- set index = (messages|length - 1) - loop.index0 %} + {%- if ns.multi_step_tool and message.role == "user" %} + {%- set content = render_content(message.content, false)|trim %} + {%- if not(content.startswith('') and content.endswith('')) %} + {%- set ns.multi_step_tool = false %} + {%- set ns.last_query_index = index %} + {%- endif %} + {%- endif %} +{%- endfor %} +{%- if ns.multi_step_tool %} + {{- raise_exception('No user query found in messages.') }} +{%- endif %} +{%- for message in messages %} + 
{%- set content = render_content(message.content, true)|trim %} + {%- if message.role == "system" %} + {%- if not loop.first %} + {{- raise_exception('System message must be at the beginning.') }} + {%- endif %} + {%- elif message.role == "user" %} + {{- '<|im_start|>' + message.role + '\n' + content + '<|im_end|>' + '\n' }} + {%- elif message.role == "assistant" %} + {%- set reasoning_content = '' %} + {%- if message.reasoning_content is string %} + {%- set reasoning_content = message.reasoning_content %} + {%- else %} + {%- if '' in content %} + {%- set reasoning_content = content.split('')[0].rstrip('\n').split('')[-1].lstrip('\n') %} + {%- set content = content.split('')[-1].lstrip('\n') %} + {%- endif %} + {%- endif %} + {%- set reasoning_content = reasoning_content|trim %} + {%- if loop.index0 > ns.last_query_index %} + {{- '<|im_start|>' + message.role + '\n\n' + reasoning_content + '\n\n\n' + content }} + {%- else %} + {{- '<|im_start|>' + message.role + '\n' + content }} + {%- endif %} + {%- if message.tool_calls and message.tool_calls is iterable and message.tool_calls is not mapping %} + {%- for tool_call in message.tool_calls %} + {%- if tool_call.function is defined %} + {%- set tool_call = tool_call.function %} + {%- endif %} + {%- if loop.first %} + {%- if content|trim %} + {{- '\n\n\n\n' }} + {%- else %} + {{- '\n\n' }} + {%- endif %} + {%- else %} + {{- '\n\n\n' }} + {%- endif %} + {%- if tool_call.arguments is defined %} + {%- for args_name, args_value in tool_call.arguments|items %} + {{- '\n' }} + {%- set args_value = args_value | tojson | safe if args_value is mapping or (args_value is sequence and args_value is not string) else args_value | string %} + {{- args_value }} + {{- '\n\n' }} + {%- endfor %} + {%- endif %} + {{- '\n' }} + {%- endfor %} + {%- endif %} + {{- '<|im_end|>\n' }} + {%- elif message.role == "tool" %} + {%- if loop.previtem and loop.previtem.role != "tool" %} + {{- '<|im_start|>user' }} + {%- endif %} + {{- '\n\n' }} + {{- 
content }} + {{- '\n' }} + {%- if not loop.last and loop.nextitem.role != "tool" %} + {{- '<|im_end|>\n' }} + {%- elif loop.last %} + {{- '<|im_end|>\n' }} + {%- endif %} + {%- else %} + {{- raise_exception('Unexpected message role.') }} + {%- endif %} +{%- endfor %} +{%- if add_generation_prompt %} + {{- '<|im_start|>assistant\n' }} + {%- if enable_thinking is defined and enable_thinking is true %} + {{- '\n' }} + {%- else %} + {{- '\n\n\n\n' }} + {%- endif %} +{%- endif %} \ No newline at end of file diff --git a/trl/chat_templates/qwen3_5_4b_and_above.jinja b/trl/chat_templates/qwen3_5_4b_and_above.jinja new file mode 100644 index 00000000000..a585dec894e --- /dev/null +++ b/trl/chat_templates/qwen3_5_4b_and_above.jinja @@ -0,0 +1,154 @@ +{%- set image_count = namespace(value=0) %} +{%- set video_count = namespace(value=0) %} +{%- macro render_content(content, do_vision_count, is_system_content=false) %} + {%- if content is string %} + {{- content }} + {%- elif content is iterable and content is not mapping %} + {%- for item in content %} + {%- if 'image' in item or 'image_url' in item or item.type == 'image' %} + {%- if is_system_content %} + {{- raise_exception('System message cannot contain images.') }} + {%- endif %} + {%- if do_vision_count %} + {%- set image_count.value = image_count.value + 1 %} + {%- endif %} + {%- if add_vision_id %} + {{- 'Picture ' ~ image_count.value ~ ': ' }} + {%- endif %} + {{- '<|vision_start|><|image_pad|><|vision_end|>' }} + {%- elif 'video' in item or item.type == 'video' %} + {%- if is_system_content %} + {{- raise_exception('System message cannot contain videos.') }} + {%- endif %} + {%- if do_vision_count %} + {%- set video_count.value = video_count.value + 1 %} + {%- endif %} + {%- if add_vision_id %} + {{- 'Video ' ~ video_count.value ~ ': ' }} + {%- endif %} + {{- '<|vision_start|><|video_pad|><|vision_end|>' }} + {%- elif 'text' in item %} + {{- item.text }} + {%- else %} + {{- raise_exception('Unexpected item type in 
content.') }} + {%- endif %} + {%- endfor %} + {%- elif content is none or content is undefined %} + {{- '' }} + {%- else %} + {{- raise_exception('Unexpected content type.') }} + {%- endif %} +{%- endmacro %} +{%- if not messages %} + {{- raise_exception('No messages provided.') }} +{%- endif %} +{%- if tools and tools is iterable and tools is not mapping %} + {{- '<|im_start|>system\n' }} + {{- "# Tools\n\nYou have access to the following functions:\n\n" }} + {%- for tool in tools %} + {{- "\n" }} + {{- tool | tojson }} + {%- endfor %} + {{- "\n" }} + {{- '\n\nIf you choose to call a function ONLY reply in the following format with NO suffix:\n\n\n\n\nvalue_1\n\n\nThis is the value for the second parameter\nthat can span\nmultiple lines\n\n\n\n\n\nReminder:\n- Function calls MUST follow the specified format: an inner block must be nested within XML tags\n- Required parameters MUST be specified\n- You may provide optional reasoning for your function call in natural language BEFORE the function call, but NOT after\n- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls\n' }} + {%- if messages[0].role == 'system' %} + {%- set content = render_content(messages[0].content, false, true)|trim %} + {%- if content %} + {{- '\n\n' + content }} + {%- endif %} + {%- endif %} + {{- '<|im_end|>\n' }} +{%- else %} + {%- if messages[0].role == 'system' %} + {%- set content = render_content(messages[0].content, false, true)|trim %} + {{- '<|im_start|>system\n' + content + '<|im_end|>\n' }} + {%- endif %} +{%- endif %} +{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %} +{%- for message in messages[::-1] %} + {%- set index = (messages|length - 1) - loop.index0 %} + {%- if ns.multi_step_tool and message.role == "user" %} + {%- set content = render_content(message.content, false)|trim %} + {%- if not(content.startswith('') and content.endswith('')) %} + {%- 
set ns.multi_step_tool = false %} + {%- set ns.last_query_index = index %} + {%- endif %} + {%- endif %} +{%- endfor %} +{%- if ns.multi_step_tool %} + {{- raise_exception('No user query found in messages.') }} +{%- endif %} +{%- for message in messages %} + {%- set content = render_content(message.content, true)|trim %} + {%- if message.role == "system" %} + {%- if not loop.first %} + {{- raise_exception('System message must be at the beginning.') }} + {%- endif %} + {%- elif message.role == "user" %} + {{- '<|im_start|>' + message.role + '\n' + content + '<|im_end|>' + '\n' }} + {%- elif message.role == "assistant" %} + {%- set reasoning_content = '' %} + {%- if message.reasoning_content is string %} + {%- set reasoning_content = message.reasoning_content %} + {%- else %} + {%- if '' in content %} + {%- set reasoning_content = content.split('')[0].rstrip('\n').split('')[-1].lstrip('\n') %} + {%- set content = content.split('')[-1].lstrip('\n') %} + {%- endif %} + {%- endif %} + {%- set reasoning_content = reasoning_content|trim %} + {%- if loop.index0 > ns.last_query_index %} + {{- '<|im_start|>' + message.role + '\n\n' + reasoning_content + '\n\n\n' + content }} + {%- else %} + {{- '<|im_start|>' + message.role + '\n' + content }} + {%- endif %} + {%- if message.tool_calls and message.tool_calls is iterable and message.tool_calls is not mapping %} + {%- for tool_call in message.tool_calls %} + {%- if tool_call.function is defined %} + {%- set tool_call = tool_call.function %} + {%- endif %} + {%- if loop.first %} + {%- if content|trim %} + {{- '\n\n\n\n' }} + {%- else %} + {{- '\n\n' }} + {%- endif %} + {%- else %} + {{- '\n\n\n' }} + {%- endif %} + {%- if tool_call.arguments is defined %} + {%- for args_name, args_value in tool_call.arguments|items %} + {{- '\n' }} + {%- set args_value = args_value | tojson | safe if args_value is mapping or (args_value is sequence and args_value is not string) else args_value | string %} + {{- args_value }} + {{- '\n\n' }} + 
{%- endfor %} + {%- endif %} + {{- '\n' }} + {%- endfor %} + {%- endif %} + {{- '<|im_end|>\n' }} + {%- elif message.role == "tool" %} + {%- if loop.previtem and loop.previtem.role != "tool" %} + {{- '<|im_start|>user' }} + {%- endif %} + {{- '\n\n' }} + {{- content }} + {{- '\n' }} + {%- if not loop.last and loop.nextitem.role != "tool" %} + {{- '<|im_end|>\n' }} + {%- elif loop.last %} + {{- '<|im_end|>\n' }} + {%- endif %} + {%- else %} + {{- raise_exception('Unexpected message role.') }} + {%- endif %} +{%- endfor %} +{%- if add_generation_prompt %} + {{- '<|im_start|>assistant\n' }} + {%- if enable_thinking is defined and enable_thinking is false %} + {{- '\n\n\n\n' }} + {%- else %} + {{- '\n' }} + {%- endif %} +{%- endif %} \ No newline at end of file diff --git a/trl/chat_templates/qwen3_training.jinja b/trl/chat_templates/qwen3_training.jinja new file mode 100644 index 00000000000..978e58e3ebf --- /dev/null +++ b/trl/chat_templates/qwen3_training.jinja @@ -0,0 +1,88 @@ +{#- Training variant of the Qwen3 chat template (see qwen3.jinja for the original). + Modifications vs the original: + - {%- if '' in content %} → {%- if '' in content and '' in content %} + Always check for both tags to avoid edge cases where the model generates only one tag. + - Removed the loop.index0 > ns.last_query_index conditional; always include thinking block. + This makes the template prefix-preserving for the [user, assistant] → [user, assistant, tool] transition. 
+-#} +{%- if tools %} + {{- '<|im_start|>system\n' }} + {%- if messages[0].role == 'system' %} + {{- messages[0].content + '\n\n' }} + {%- endif %} + {{- "# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within XML tags:\n" }} + {%- for tool in tools %} + {{- "\n" }} + {{- tool | tojson }} + {%- endfor %} + {{- "\n\n\nFor each function call, return a json object with function name and arguments within XML tags:\n\n{\"name\": , \"arguments\": }\n<|im_end|>\n" }} +{%- else %} + {%- if messages[0].role == 'system' %} + {{- '<|im_start|>system\n' + messages[0].content + '<|im_end|>\n' }} + {%- endif %} +{%- endif %} +{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %} +{%- for message in messages[::-1] %} + {%- set index = (messages|length - 1) - loop.index0 %} + {%- if ns.multi_step_tool and message.role == "user" and message.content is string and not(message.content.startswith('') and message.content.endswith('')) %} + {%- set ns.multi_step_tool = false %} + {%- set ns.last_query_index = index %} + {%- endif %} +{%- endfor %} +{%- for message in messages %} + {%- if message.content is string %} + {%- set content = message.content %} + {%- else %} + {%- set content = '' %} + {%- endif %} + {%- if (message.role == "user") or (message.role == "system" and not loop.first) %} + {{- '<|im_start|>' + message.role + '\n' + content + '<|im_end|>' + '\n' }} + {%- elif message.role == "assistant" %} + {%- set reasoning_content = '' %} + {%- if message.reasoning_content is string %} + {%- set reasoning_content = message.reasoning_content %} + {%- else %} + {%- if '' in content and '' in content %} + {%- set reasoning_content = content.split('')[0].rstrip('\n').split('')[-1].lstrip('\n') %} + {%- set content = content.split('')[-1].lstrip('\n') %} + {%- endif %} + {%- endif %} + {{- '<|im_start|>' + message.role + '\n\n' + reasoning_content.strip('\n') + '\n\n\n' + 
content.lstrip('\n') }} + {%- if message.tool_calls %} + {%- for tool_call in message.tool_calls %} + {%- if (loop.first and content) or (not loop.first) %} + {{- '\n' }} + {%- endif %} + {%- if tool_call.function %} + {%- set tool_call = tool_call.function %} + {%- endif %} + {{- '\n{"name": "' }} + {{- tool_call.name }} + {{- '", "arguments": ' }} + {%- if tool_call.arguments is string %} + {{- tool_call.arguments }} + {%- else %} + {{- tool_call.arguments | tojson }} + {%- endif %} + {{- '}\n' }} + {%- endfor %} + {%- endif %} + {{- '<|im_end|>\n' }} + {%- elif message.role == "tool" %} + {%- if loop.first or (messages[loop.index0 - 1].role != "tool") %} + {{- '<|im_start|>user' }} + {%- endif %} + {{- '\n\n' }} + {{- content }} + {{- '\n' }} + {%- if loop.last or (messages[loop.index0 + 1].role != "tool") %} + {{- '<|im_end|>\n' }} + {%- endif %} + {%- endif %} +{%- endfor %} +{%- if add_generation_prompt %} + {{- '<|im_start|>assistant\n' }} + {%- if enable_thinking is defined and enable_thinking is false %} + {{- '\n\n\n\n' }} + {%- endif %} +{%- endif %} \ No newline at end of file From 276559da2718c411c4df7f8ca397c7dfcfc6503b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Sun, 5 Apr 2026 19:48:57 +0000 Subject: [PATCH 05/15] tools in dummy --- trl/chat_template_utils.py | 36 ++++++++++++++++++------------------ trl/trainer/grpo_trainer.py | 8 +++++++- 2 files changed, 25 insertions(+), 19 deletions(-) diff --git a/trl/chat_template_utils.py b/trl/chat_template_utils.py index 3e86874816a..615722cac8c 100644 --- a/trl/chat_template_utils.py +++ b/trl/chat_template_utils.py @@ -16,6 +16,7 @@ from transformers import AddedToken, AutoTokenizer, PreTrainedModel, PreTrainedTokenizer + _CHAT_TEMPLATES_DIR = Path(__file__).parent / "chat_templates" @@ -237,9 +238,9 @@ def is_chat_template_prefix_preserving(tokenizer: PreTrainedTokenizer) -> bool: """ Check whether the chat template preserves prefixes when applied. 
- A prefix-preserving chat template renders earlier messages identically regardless of what messages follow. - This property is required by `_get_tool_suffix_ids`, which extracts tool response formatting tokens by - comparing tokenizations with and without tool messages appended. + A prefix-preserving chat template renders earlier messages identically regardless of what messages follow. This + property is required by `_get_tool_suffix_ids`, which extracts tool response formatting tokens by comparing + tokenizations with and without tool messages appended. Args: tokenizer (`PreTrainedTokenizer`): @@ -249,22 +250,16 @@ def is_chat_template_prefix_preserving(tokenizer: PreTrainedTokenizer) -> bool: `bool`: `True` if the chat template preserves prefixes, `False` otherwise. """ + # Use the same dummy messages as _get_tool_suffix_ids to test the exact property it relies on. + dummy_tool_calls = [{"type": "function", "function": {"name": "dummy", "arguments": {}}}] messages1 = [ - {"role": "user", "content": "What is 2 * 3?"}, - { - "role": "assistant", - "content": "", - "tool_calls": [{"type": "function", "function": {"name": "multiply", "arguments": {"a": 2, "b": 3}}}], - }, + {"role": "user", "content": "dummy"}, + {"role": "assistant", "tool_calls": dummy_tool_calls}, ] messages2 = [ - {"role": "user", "content": "What is 2 * 3?"}, - { - "role": "assistant", - "content": "", - "tool_calls": [{"type": "function", "function": {"name": "multiply", "arguments": {"a": 2, "b": 3}}}], - }, - {"role": "tool", "name": "multiply", "content": "6"}, + {"role": "user", "content": "dummy"}, + {"role": "assistant", "tool_calls": dummy_tool_calls}, + {"role": "tool", "name": "dummy", "content": "dummy"}, ] text1 = tokenizer.apply_chat_template(messages1, tokenize=False) @@ -300,7 +295,10 @@ def get_training_chat_template(tokenizer: PreTrainedTokenizer) -> str | None: >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B") >>> messages1 = [ ... 
{"role": "user", "content": "What is 2 * 3?"}, - ... {"role": "assistant", "content": "", "tool_calls": [{"type": "function", "function": {"name": "multiply", "arguments": {"a": 2, "b": 3}}}]}, + ... { + ... "role": "assistant", + ... "tool_calls": [{"type": "function", "function": {"name": "multiply", "arguments": {"a": 2, "b": 3}}}], + ... }, ... ] >>> messages2 = messages1 + [ ... {"role": "tool", "name": "multiply", "content": "6"}, @@ -316,7 +314,9 @@ def get_training_chat_template(tokenizer: PreTrainedTokenizer) -> str | None: >>> tokenizer.apply_chat_template(messages1, tokenize=False, chat_template=chat_template) '<|im_start|>user\nWhat is 2 * 3?<|im_end|>\n<|im_start|>assistant\n\n\n\n\n\n{"name": "multiply", "arguments": {"a": 2, "b": 3}}\n<|im_end|>\n' - >>> tokenizer.apply_chat_template(messages2, tokenize=False, add_generation_prompt=True, chat_template=chat_template) + >>> tokenizer.apply_chat_template( + ... messages2, tokenize=False, add_generation_prompt=True, chat_template=chat_template + ... 
) '<|im_start|>user\nWhat is 2 * 3?<|im_end|>\n<|im_start|>assistant\n\n\n\n\n\n{"name": "multiply", "arguments": {"a": 2, "b": 3}}\n<|im_end|>\n<|im_start|>user\n\n6\n<|im_end|>\n<|im_start|>assistant\n' ``` """ diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index 0cd9c53aabe..d36e1237eb3 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -1406,7 +1406,13 @@ def _generate_single_turn(self, prompt_ids, images, multimodal_fields): def _get_tool_suffix_ids(self, tool_messages): """Get token IDs for tool result formatting by using a minimal dummy conversation.""" - dummy_messages = [{"role": "user", "content": "dummy"}, {"role": "assistant", "content": "dummy"}] + dummy_messages = [ + {"role": "user", "content": "dummy"}, + { + "role": "assistant", + "tool_calls": [{"type": "function", "function": {"name": "dummy", "arguments": {}}}], + }, + ] if self._is_vlm: dummy_messages = prepare_multimodal_messages(dummy_messages) prefix_ids = self.processing_class.apply_chat_template( From 673c35d21fab7d533cdf493d5513b243f548e1c4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Sun, 5 Apr 2026 19:56:55 +0000 Subject: [PATCH 06/15] Add chat template files to MANIFEST.in --- MANIFEST.in | 2 ++ 1 file changed, 2 insertions(+) diff --git a/MANIFEST.in b/MANIFEST.in index 16cf649ce60..e843c590270 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -3,6 +3,8 @@ include CONTRIBUTING.md include README.md include trl/accelerate_configs/*.yaml include trl/templates/*.md +include trl/chat_templates/*.jinja +include trl/chat_templates/*.md include trl/skills/**/*.md recursive-exclude * __pycache__ prune tests From 604c476258df074c66729e72ec591ea9db702c4d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Sun, 5 Apr 2026 20:10:46 +0000 Subject: [PATCH 07/15] Enhance chat template handling to include tool call formatting in messages + async grpo --- tests/test_chat_template_utils.py | 24 
++++++++++++++++--- .../async_grpo/async_rollout_worker.py | 5 ++-- 2 files changed, 24 insertions(+), 5 deletions(-) diff --git a/tests/test_chat_template_utils.py b/tests/test_chat_template_utils.py index 2f6cecbb262..b65231fb76c 100644 --- a/tests/test_chat_template_utils.py +++ b/tests/test_chat_template_utils.py @@ -148,13 +148,23 @@ def test_prefix_preserving_template(self): tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen3MoeForSequenceClassification") tokenizer.chat_template = textwrap.dedent(r""" {%- for message in messages %} + {%- set content = message.content if message.content is string else '' %} {%- if message.role == 'user' %} - {{- '<|im_start|>user\n' + message.content + '<|im_end|>\n' }} + {{- '<|im_start|>user\n' + content + '<|im_end|>\n' }} {%- elif message.role == 'assistant' %} - {{- '<|im_start|>assistant\n' + message.content + '<|im_end|>\n' }} + {{- '<|im_start|>assistant\n' + content }} + {%- if message.tool_calls %} + {%- for tool_call in message.tool_calls %} + {%- if tool_call.function %} + {%- set tool_call = tool_call.function %} + {%- endif %} + {{- '' + tool_call.name + '' }} + {%- endfor %} + {%- endif %} + {{- '<|im_end|>\n' }} {%- elif message.role == 'tool' %} - {{- '<|im_start|>tool\n' + message.content + '<|im_end|>\n' }} + {{- '<|im_start|>tool\n' + content + '<|im_end|>\n' }} {%- endif %} {%- endfor %} @@ -205,6 +215,14 @@ def test_non_prefix_preserving_template(self): {%- else %} {{- '<|im_start|>' + message.role + '\n' + content }} {%- endif %} + {%- if message.tool_calls %} + {%- for tool_call in message.tool_calls %} + {%- if tool_call.function %} + {%- set tool_call = tool_call.function %} + {%- endif %} + {{- '' + tool_call.name + '' }} + {%- endfor %} + {%- endif %} {{- '<|im_end|>\n' }} {%- elif message.role == "tool" %} {{- '<|im_start|>tool\n' + content + '<|im_end|>\n' }} diff --git a/trl/experimental/async_grpo/async_rollout_worker.py 
b/trl/experimental/async_grpo/async_rollout_worker.py index 3d7350d5a71..d4b64860692 100644 --- a/trl/experimental/async_grpo/async_rollout_worker.py +++ b/trl/experimental/async_grpo/async_rollout_worker.py @@ -555,9 +555,10 @@ async def _generate_one( def _get_tool_suffix_ids(self, tool_messages: list[dict[str, Any]]) -> list[int]: """Get token IDs for tool result formatting by using a minimal dummy conversation.""" + dummy_tool_calls = [{"type": "function", "function": {"name": "dummy", "arguments": {}}}] dummy_messages = [ - {"role": "user", "content": ""}, - {"role": "assistant", "content": ""}, + {"role": "user", "content": "dummy"}, + {"role": "assistant", "content": "dummy", "tool_calls": dummy_tool_calls}, ] prefix_ids = self.tokenizer.apply_chat_template( dummy_messages, From 83a7ef66d9d563d3b57c6cf56b8b25d4f17f6d9c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Sun, 5 Apr 2026 20:18:59 +0000 Subject: [PATCH 08/15] align grpo and async --- .../async_grpo/async_rollout_worker.py | 21 +++++++++++-------- trl/trainer/grpo_trainer.py | 6 ++---- 2 files changed, 14 insertions(+), 13 deletions(-) diff --git a/trl/experimental/async_grpo/async_rollout_worker.py b/trl/experimental/async_grpo/async_rollout_worker.py index d4b64860692..2d0e4b2b6c7 100644 --- a/trl/experimental/async_grpo/async_rollout_worker.py +++ b/trl/experimental/async_grpo/async_rollout_worker.py @@ -558,29 +558,32 @@ def _get_tool_suffix_ids(self, tool_messages: list[dict[str, Any]]) -> list[int] dummy_tool_calls = [{"type": "function", "function": {"name": "dummy", "arguments": {}}}] dummy_messages = [ {"role": "user", "content": "dummy"}, - {"role": "assistant", "content": "dummy", "tool_calls": dummy_tool_calls}, + {"role": "assistant", "tool_calls": dummy_tool_calls}, ] prefix_ids = self.tokenizer.apply_chat_template( dummy_messages, - return_dict=False, - tools=self.tools or None, # `or None`: Llama bug: it renders tool boilerplate for tools=[] + 
add_generation_prompt=False, + tokenize=True, chat_template=self.chat_template, + return_dict=False, **self.chat_template_kwargs, ) full_ids = self.tokenizer.apply_chat_template( dummy_messages + tool_messages, - return_dict=False, - chat_template=self.chat_template, add_generation_prompt=True, + tokenize=True, + chat_template=self.chat_template, + return_dict=False, **self.chat_template_kwargs, ) # Some chat templates (notably Qwen3/Qwen3.5) render "...<|im_end|>\n" after an assistant/tool block. # When we compute `suffix_ids` by slicing `full_ids`, we must align the slicing boundary to - # EOS (not EOS + newline). - last_eos_idx = max(i for i, tok_id in enumerate(prefix_ids) if tok_id == self.tokenizer.eos_token_id) - prefix_ids = prefix_ids[: last_eos_idx + 1] + # EOS (not EOS + newline). Templates that don't use EOS as end-of-turn (e.g. Gemma uses + # <end_of_turn>) skip this trimming. + eos_positions = [i for i, tok_id in enumerate(prefix_ids) if tok_id == self.tokenizer.eos_token_id] + if eos_positions: + prefix_ids = prefix_ids[: eos_positions[-1] + 1] if full_ids[: len(prefix_ids)] != prefix_ids: raise ValueError("Unexpected tokenization: the EOS-trimmed prefix IDs are not a prefix of the full IDs.") diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index d36e1237eb3..0d172190777 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -1406,12 +1406,10 @@ def _get_tool_suffix_ids(self, tool_messages): """Get token IDs for tool result formatting by using a minimal dummy conversation.""" + dummy_tool_calls = [{"type": "function", "function": {"name": "dummy", "arguments": {}}}] dummy_messages = [ {"role": "user", "content": "dummy"}, - { - "role": "assistant", - "tool_calls": [{"type": "function", "function": {"name": "dummy", "arguments": {}}}], - }, + {"role": 
"assistant", "tool_calls": dummy_tool_calls}, ] if self._is_vlm: dummy_messages = prepare_multimodal_messages(dummy_messages) From e5d7cdf70a0925dc120ab62c0c8078c28a5f70df Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Mon, 6 Apr 2026 14:04:54 +0000 Subject: [PATCH 09/15] revert no content --- tests/test_chat_template_utils.py | 9 ++++----- trl/chat_template_utils.py | 5 +++-- trl/experimental/async_grpo/async_rollout_worker.py | 2 +- trl/trainer/grpo_trainer.py | 2 +- 4 files changed, 9 insertions(+), 9 deletions(-) diff --git a/tests/test_chat_template_utils.py b/tests/test_chat_template_utils.py index b65231fb76c..1d1840c623a 100644 --- a/tests/test_chat_template_utils.py +++ b/tests/test_chat_template_utils.py @@ -148,12 +148,11 @@ def test_prefix_preserving_template(self): tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen3MoeForSequenceClassification") tokenizer.chat_template = textwrap.dedent(r""" {%- for message in messages %} - {%- set content = message.content if message.content is string else '' %} {%- if message.role == 'user' %} - {{- '<|im_start|>user\n' + content + '<|im_end|>\n' }} + {{- '<|im_start|>user\n' + message.content + '<|im_end|>\n' }} {%- elif message.role == 'assistant' %} - {{- '<|im_start|>assistant\n' + content }} + {{- '<|im_start|>assistant\n' + message.content }} {%- if message.tool_calls %} {%- for tool_call in message.tool_calls %} {%- if tool_call.function %} @@ -164,7 +163,7 @@ def test_prefix_preserving_template(self): {%- endif %} {{- '<|im_end|>\n' }} {%- elif message.role == 'tool' %} - {{- '<|im_start|>tool\n' + content + '<|im_end|>\n' }} + {{- '<|im_start|>tool\n' + message.content + '<|im_end|>\n' }} {%- endif %} {%- endfor %} @@ -471,7 +470,7 @@ def test_parse_response_tool_call_without_arguments(self, tokenizer_name): tool_calls = [{"type": "function", "function": {"name": "ping", "arguments": {}}}] messages = [ {"role": "user", "content": "Ping the service."}, - 
{"role": "assistant", "tool_calls": tool_calls}, + {"role": "assistant", "content": "", "tool_calls": tool_calls}, ] prefix = tokenizer.apply_chat_template(messages[:1], add_generation_prompt=True).input_ids text = tokenizer.apply_chat_template(messages).input_ids diff --git a/trl/chat_template_utils.py b/trl/chat_template_utils.py index 615722cac8c..fae77c65f15 100644 --- a/trl/chat_template_utils.py +++ b/trl/chat_template_utils.py @@ -254,11 +254,11 @@ def is_chat_template_prefix_preserving(tokenizer: PreTrainedTokenizer) -> bool: dummy_tool_calls = [{"type": "function", "function": {"name": "dummy", "arguments": {}}}] messages1 = [ {"role": "user", "content": "dummy"}, - {"role": "assistant", "tool_calls": dummy_tool_calls}, + {"role": "assistant", "content": "", "tool_calls": dummy_tool_calls}, ] messages2 = [ {"role": "user", "content": "dummy"}, - {"role": "assistant", "tool_calls": dummy_tool_calls}, + {"role": "assistant", "content": "", "tool_calls": dummy_tool_calls}, {"role": "tool", "name": "dummy", "content": "dummy"}, ] @@ -297,6 +297,7 @@ def get_training_chat_template(tokenizer: PreTrainedTokenizer) -> str | None: ... {"role": "user", "content": "What is 2 * 3?"}, ... { ... "role": "assistant", + ... "content": "", ... "tool_calls": [{"type": "function", "function": {"name": "multiply", "arguments": {"a": 2, "b": 3}}}], ... }, ... 
] diff --git a/trl/experimental/async_grpo/async_rollout_worker.py b/trl/experimental/async_grpo/async_rollout_worker.py index 2d0e4b2b6c7..1a63a7d8e90 100644 --- a/trl/experimental/async_grpo/async_rollout_worker.py +++ b/trl/experimental/async_grpo/async_rollout_worker.py @@ -558,7 +558,7 @@ def _get_tool_suffix_ids(self, tool_messages: list[dict[str, Any]]) -> list[int] dummy_tool_calls = [{"type": "function", "function": {"name": "dummy", "arguments": {}}}] dummy_messages = [ {"role": "user", "content": "dummy"}, - {"role": "assistant", "tool_calls": dummy_tool_calls}, + {"role": "assistant", "content": "", "tool_calls": dummy_tool_calls}, ] prefix_ids = self.tokenizer.apply_chat_template( dummy_messages, diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index 0d172190777..aa64eec2b5a 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -1409,7 +1409,7 @@ def _get_tool_suffix_ids(self, tool_messages): dummy_tool_calls = [{"type": "function", "function": {"name": "dummy", "arguments": {}}}] dummy_messages = [ {"role": "user", "content": "dummy"}, - {"role": "assistant", "tool_calls": dummy_tool_calls}, + {"role": "assistant", "content": "", "tool_calls": dummy_tool_calls}, ] if self._is_vlm: dummy_messages = prepare_multimodal_messages(dummy_messages) From a618809dbe7460f5ee5092545f91c7e29d641df1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Mon, 6 Apr 2026 14:06:33 +0000 Subject: [PATCH 10/15] docstyle ignore --- tests/test_chat_template_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_chat_template_utils.py b/tests/test_chat_template_utils.py index 1d1840c623a..50eb8af4aad 100644 --- a/tests/test_chat_template_utils.py +++ b/tests/test_chat_template_utils.py @@ -146,6 +146,7 @@ def test_add_response_schema(self, tokenizer_name): class TestIsChatTemplatePrefixPreserving: def test_prefix_preserving_template(self): tokenizer = 
AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen3MoeForSequenceClassification") + # docstyle-ignore tokenizer.chat_template = textwrap.dedent(r""" {%- for message in messages %} From 7b7f5d15eca0b0866c1afee75df66bc637cfd8fe Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Tue, 7 Apr 2026 14:55:40 +0000 Subject: [PATCH 11/15] revert old modif --- tests/test_chat_template_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_chat_template_utils.py b/tests/test_chat_template_utils.py index 155bf77d9b4..2cdd45b85fb 100644 --- a/tests/test_chat_template_utils.py +++ b/tests/test_chat_template_utils.py @@ -490,7 +490,7 @@ def test_parse_response_tool_call_without_arguments(self, tokenizer_name): tool_calls = [{"type": "function", "function": {"name": "ping", "arguments": {}}}] messages = [ {"role": "user", "content": "Ping the service."}, - {"role": "assistant", "content": "", "tool_calls": tool_calls}, + {"role": "assistant", "tool_calls": tool_calls}, ] prefix = tokenizer.apply_chat_template(messages[:1], add_generation_prompt=True).input_ids text = tokenizer.apply_chat_template(messages).input_ids From 3c2fc8e38f55746d3b2a7616c08e29f3a280dd10 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Tue, 7 Apr 2026 17:34:16 +0000 Subject: [PATCH 12/15] Add `{% generation %}` support to training chat templates --- docs/source/sft_trainer.md | 2 +- tests/test_chat_template_utils.py | 35 ++++++++++++++++++++++++- trl/chat_template_utils.py | 2 +- trl/chat_templates/README.md | 10 ++++--- trl/chat_templates/qwen3_training.jinja | 4 +++ trl/trainer/sft_trainer.py | 12 +++++++-- 6 files changed, 57 insertions(+), 8 deletions(-) diff --git a/docs/source/sft_trainer.md b/docs/source/sft_trainer.md index b99dd9f2e59..089bfc595ff 100644 --- a/docs/source/sft_trainer.md +++ b/docs/source/sft_trainer.md @@ -169,7 +169,7 @@ training_args = SFTConfig(assistant_only_loss=True) 
![train_on_assistant](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/train_on_assistant.png) > [!WARNING] -> This functionality is only available for chat templates that support returning the assistant tokens mask via the `{% generation %}` and `{% endgeneration %}` keywords. For an example of such a template, see [HugggingFaceTB/SmolLM3-3B](https://huggingface.co/HuggingFaceTB/SmolLM3-3B/blob/main/chat_template.jinja#L76-L82). +> This functionality requires the chat template to include `{% generation %}` and `{% endgeneration %}` keywords. For known model families (e.g. Qwen3), TRL automatically patches the template when `assistant_only_loss=True`. For other models, check that your chat template includes these keywords — see [HuggingFaceTB/SmolLM3-3B](https://huggingface.co/HuggingFaceTB/SmolLM3-3B/blob/main/chat_template.jinja#L76-L82) for an example. ### Train on completion only diff --git a/tests/test_chat_template_utils.py b/tests/test_chat_template_utils.py index 2cdd45b85fb..c61f23facc5 100644 --- a/tests/test_chat_template_utils.py +++ b/tests/test_chat_template_utils.py @@ -246,7 +246,6 @@ def test_non_prefix_preserving_template(self): class TestGetTrainingChatTemplate: def test_new_chat_template_is_prefix_preserving(self, tokenizer_name): tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) - assert is_chat_template_prefix_preserving(tokenizer) is False tokenizer.chat_template = get_training_chat_template(tokenizer) assert is_chat_template_prefix_preserving(tokenizer) is True @@ -392,6 +391,40 @@ def test_behavior_unchanged_generation_prompt_with_enable_thinking_false(self, t ) assert before == after + def test_assistant_masks(self, tokenizer_name): + tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) + messages = [ + {"role": "user", "content": "What color is the sky?"}, + {"role": "assistant", "content": "It is blue."}, + ] + chat_template = get_training_chat_template(tokenizer) + result = 
tokenizer.apply_chat_template( + messages, chat_template=chat_template, return_assistant_tokens_mask=True, return_dict=True + ) + masks = result["assistant_masks"] + assert 1 in masks + # The first tokens (user turn) should not be masked + assert masks[0] == 0 + # The last tokens (assistant turn ending with <|im_end|>) should be masked + assert masks[-1] == 1 + + def test_assistant_masks_multi_turn(self, tokenizer_name): + tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) + messages = [ + {"role": "user", "content": "Hi"}, + {"role": "assistant", "content": "Hello!"}, + {"role": "user", "content": "Bye"}, + {"role": "assistant", "content": "Goodbye!"}, + ] + chat_template = get_training_chat_template(tokenizer) + result = tokenizer.apply_chat_template( + messages, chat_template=chat_template, return_assistant_tokens_mask=True, return_dict=True + ) + masks = result["assistant_masks"] + # Should have two masked regions (two assistant turns): 0→1, 1→0, 0→1 + transitions = sum(1 for i in range(1, len(masks)) if masks[i] != masks[i - 1]) + assert transitions == 3 + @pytest.mark.parametrize( "tokenizer_name", diff --git a/trl/chat_template_utils.py b/trl/chat_template_utils.py index c85a7b45a04..a576eacb910 100644 --- a/trl/chat_template_utils.py +++ b/trl/chat_template_utils.py @@ -322,7 +322,7 @@ def get_training_chat_template(tokenizer: PreTrainedTokenizer) -> str | None: ``` """ # First check if patching is needed - if is_chat_template_prefix_preserving(tokenizer): + if is_chat_template_prefix_preserving(tokenizer) and "{% generation %}" in tokenizer.chat_template: return None # No patching needed if tokenizer.chat_template == qwen3_chat_template: diff --git a/trl/chat_templates/README.md b/trl/chat_templates/README.md index 5eb01cc4aca..d7e2d84eaea 100644 --- a/trl/chat_templates/README.md +++ b/trl/chat_templates/README.md @@ -3,10 +3,12 @@ Jinja2 chat templates stored here serve two purposes: 1. 
**Identity comparison**: detecting which model is being used (by comparing `tokenizer.chat_template` against known templates) to add the appropriate response schema (`add_response_schema`) or swap in a training template (`get_training_chat_template`). -2. **Training patches**: modified templates that fix prefix-preservation issues for the GRPO tool call loop. +2. **Training patches**: modified templates that fix training-specific issues (prefix-preservation for GRPO, `{% generation %}` markers for SFT assistant-only loss). **Why prefix-preserving?** The GRPO tool call loop extracts tool response formatting tokens by comparing tokenizations with and without tool messages appended (`_get_tool_suffix_ids`). This requires the chat template to be *prefix-preserving*: appending messages must not change how earlier messages are rendered. +**Why generation-tagged?** SFT with `assistant_only_loss=True` requires the chat template to include `{% generation %}` / `{% endgeneration %}` markers around assistant output, so `return_assistant_tokens_mask=True` can produce correct masks. Most model templates don't include these markers natively. + ## Original templates Used for identity comparison only. @@ -21,11 +23,11 @@ Original Qwen3.5 chat templates. ## Training templates -Patched templates that fix prefix-preservation issues. Swapped in at init when tools are enabled. +Patched templates that fix training-specific issues. Swapped in at init when tools are enabled (GRPO) or when `assistant_only_loss=True` (SFT). ### `qwen3_training.jinja` -Patched Qwen3 template that always includes thinking blocks, making it prefix-preserving. Diff vs `qwen3.jinja`: +Patched Qwen3 template. Diff vs `qwen3.jinja`: Require both `<think>` and `</think>` to be present before parsing, to avoid incorrect splitting when the model generates only one tag: @@ -48,3 +50,5 @@ Always include the thinking block regardless of message position.
The original c - {%- endif %} + {{- '<|im_start|>' + message.role + '\n<think>\n' + reasoning_content.strip('\n') + '\n</think>\n\n' + content.lstrip('\n') }} ``` + +Wrap assistant message output with `{% generation %}` / `{% endgeneration %}` so that `return_assistant_tokens_mask=True` produces correct masks for SFT assistant-only loss. diff --git a/trl/chat_templates/qwen3_training.jinja b/trl/chat_templates/qwen3_training.jinja index 978e58e3ebf..3287c3d2ce3 100644 --- a/trl/chat_templates/qwen3_training.jinja +++ b/trl/chat_templates/qwen3_training.jinja @@ -4,6 +4,8 @@ Always check for both tags to avoid edge cases where the model generates only one tag. - Removed the loop.index0 > ns.last_query_index conditional; always include thinking block. This makes the template prefix-preserving for the [user, assistant] → [user, assistant, tool] transition. + - Added {% generation %} / {% endgeneration %} around assistant message output to support + assistant-only loss masking in SFT training. -#} {%- if tools %} {{- '<|im_start|>system\n' }} @@ -47,6 +49,7 @@ {%- set content = content.split('</think>')[-1].lstrip('\n') %} {%- endif %} {%- endif %} + {% generation %} {{- '<|im_start|>' + message.role + '\n<think>\n' + reasoning_content.strip('\n') + '\n</think>\n\n' + content.lstrip('\n') }} {%- if message.tool_calls %} {%- for tool_call in message.tool_calls %} @@ -68,6 +71,7 @@ {%- endfor %} {%- endif %} {{- '<|im_end|>\n' }} + {%- endgeneration %} {%- elif message.role == "tool" %} {%- if loop.first or (messages[loop.index0 - 1].role != "tool") %} {{- '<|im_start|>user' }} diff --git a/trl/trainer/sft_trainer.py b/trl/trainer/sft_trainer.py index 1807520dc72..273337c6889 100644 --- a/trl/trainer/sft_trainer.py +++ b/trl/trainer/sft_trainer.py @@ -43,7 +43,7 @@ from transformers.trainer_utils import EvalPrediction from transformers.utils import is_peft_available -from ..chat_template_utils import clone_chat_template +from ..chat_template_utils import clone_chat_template, get_training_chat_template from
..data_utils import ( apply_chat_template, is_conversational, @@ -922,6 +922,12 @@ def __init__( "supported for conversational datasets." ) + # When assistant_only_loss is enabled, swap in a training chat template with {% generation %} markers + # if the current template doesn't already have them. + if args.assistant_only_loss and "{% generation %}" not in processing_class.chat_template: + self.chat_template = get_training_chat_template(processing_class) + else: + self.chat_template = processing_class.chat_template # Dataset if self.padding_free and not args.packing and args.max_length is not None and not self._is_vision_dataset: raise ValueError( @@ -1038,7 +1044,9 @@ def _tokenize( if isinstance(input, list): # conversational: list of message dicts if self._is_vlm: input = prepare_multimodal_messages(input) - result = processing_class.apply_chat_template(input, tokenize=True, return_dict=True, **kwargs) + result = processing_class.apply_chat_template( + input, tokenize=True, return_dict=True, chat_template=self.chat_template, **kwargs + ) else: # non-conversational: plain text string result = processing_class(text=input) # VLMs emit a batch dimension even for single examples; unwrap it From d8877242f1103a963ea663db60dda8398b5e2cf5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Tue, 7 Apr 2026 17:55:16 +0000 Subject: [PATCH 13/15] conditionally call this in grpo --- trl/chat_template_utils.py | 10 ++++++---- trl/experimental/async_grpo/async_rollout_worker.py | 13 +++++++++++-- trl/trainer/grpo_trainer.py | 9 +++++++-- 3 files changed, 24 insertions(+), 8 deletions(-) diff --git a/trl/chat_template_utils.py b/trl/chat_template_utils.py index a576eacb910..6f8fcfb1e01 100644 --- a/trl/chat_template_utils.py +++ b/trl/chat_template_utils.py @@ -273,10 +273,11 @@ def is_chat_template_prefix_preserving(tokenizer: PreTrainedTokenizer) -> bool: def get_training_chat_template(tokenizer: PreTrainedTokenizer) -> str | None: r""" - Get a 
prefix-preserving chat template for training, if needed. + Get a training-compatible chat template, if needed. - If the tokenizer's template isn't prefix-preserving, returns a training-compatible template (currently Qwen3 - supported). Otherwise, returns `None`. + Returns a patched chat template that is prefix-preserving and includes `{%% generation %%}` / `{%% endgeneration + %%}` markers for assistant-only loss masking. Returns `None` if the tokenizer's template already satisfies both + requirements. Currently Qwen3 is supported. Args: tokenizer (`PreTrainedTokenizer`): @@ -329,7 +330,8 @@ def get_training_chat_template(tokenizer: PreTrainedTokenizer) -> str | None: return qwen3_training_chat_template else: raise ValueError( - "The tokenizer's chat template is not prefix-preserving and patching is not supported for this template. " + "The tokenizer's chat template is not training-compatible (missing prefix-preservation or " + "`{% generation %}` markers) and patching is not supported for this template. " "Please manually modify the tokenizer's chat template for training." 
) diff --git a/trl/experimental/async_grpo/async_rollout_worker.py b/trl/experimental/async_grpo/async_rollout_worker.py index df32627641a..179868f8734 100644 --- a/trl/experimental/async_grpo/async_rollout_worker.py +++ b/trl/experimental/async_grpo/async_rollout_worker.py @@ -28,7 +28,12 @@ from datasets import Dataset from transformers import AutoTokenizer -from trl.chat_template_utils import add_response_schema, get_training_chat_template, parse_response +from trl.chat_template_utils import ( + add_response_schema, + get_training_chat_template, + is_chat_template_prefix_preserving, + parse_response, +) from trl.import_utils import is_vllm_available from trl.trainer.utils import print_prompt_completions_sample @@ -162,7 +167,11 @@ def __init__( self.num_completions_to_print = num_completions_to_print self.tokenizer = AutoTokenizer.from_pretrained(model_name) self.tokenizer = add_response_schema(self.tokenizer) - self.chat_template = get_training_chat_template(self.tokenizer) + self.chat_template = ( + get_training_chat_template(self.tokenizer) + if not is_chat_template_prefix_preserving(self.tokenizer) + else None + ) self._groups_to_score: asyncio.Queue[RolloutGroup | None] = asyncio.Queue(maxsize=16) self._total_completion_tokens = 0 diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index d775baefd37..b660f6cfff3 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -56,7 +56,12 @@ ) from transformers.utils import is_peft_available, is_rich_available -from ..chat_template_utils import add_response_schema, get_training_chat_template, parse_response +from ..chat_template_utils import ( + add_response_schema, + get_training_chat_template, + is_chat_template_prefix_preserving, + parse_response, +) from ..data_utils import apply_chat_template, is_conversational, prepare_multimodal_messages from ..extras.profiling import profiling_context, profiling_decorator from ..generation.vllm_generation import VLLMGeneration @@ 
-515,7 +520,7 @@ def __init__( processing_class = add_response_schema(processing_class) # In multi-turn training, the chat template *must* be prefix-preserving. If the tokenizer's original template # isn't, we replace it at initialization with a training-safe, prefix-preserving template. - if self.tools: + if self.tools and not is_chat_template_prefix_preserving(processing_class): self.chat_template = get_training_chat_template(processing_class) else: self.chat_template = None From 9d4d57f9cc3a9e0266755f0781ab4857810de7f3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Tue, 7 Apr 2026 18:01:24 +0000 Subject: [PATCH 14/15] align --- trl/experimental/async_grpo/async_rollout_worker.py | 11 ++++++----- trl/trainer/sft_trainer.py | 3 ++- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/trl/experimental/async_grpo/async_rollout_worker.py b/trl/experimental/async_grpo/async_rollout_worker.py index 179868f8734..c6d4dd4c7bc 100644 --- a/trl/experimental/async_grpo/async_rollout_worker.py +++ b/trl/experimental/async_grpo/async_rollout_worker.py @@ -167,11 +167,12 @@ def __init__( self.num_completions_to_print = num_completions_to_print self.tokenizer = AutoTokenizer.from_pretrained(model_name) self.tokenizer = add_response_schema(self.tokenizer) - self.chat_template = ( - get_training_chat_template(self.tokenizer) - if not is_chat_template_prefix_preserving(self.tokenizer) - else None - ) + # In multi-turn training, the chat template *must* be prefix-preserving. If the tokenizer's original template + # isn't, we replace it at initialization with a training-safe, prefix-preserving template. 
+ if self.tools and not is_chat_template_prefix_preserving(self.tokenizer): + self.chat_template = get_training_chat_template(self.tokenizer) + else: + self.chat_template = None self._groups_to_score: asyncio.Queue[RolloutGroup | None] = asyncio.Queue(maxsize=16) self._total_completion_tokens = 0 diff --git a/trl/trainer/sft_trainer.py b/trl/trainer/sft_trainer.py index 273337c6889..fe0b3fbe106 100644 --- a/trl/trainer/sft_trainer.py +++ b/trl/trainer/sft_trainer.py @@ -927,7 +927,8 @@ def __init__( if args.assistant_only_loss and "{% generation %}" not in processing_class.chat_template: self.chat_template = get_training_chat_template(processing_class) else: - self.chat_template = processing_class.chat_template + self.chat_template = None + # Dataset if self.padding_free and not args.packing and args.max_length is not None and not self._is_vision_dataset: raise ValueError( From b33bc667b78a8d7ad711902f90d8b1af1c1b3d03 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Wed, 8 Apr 2026 14:23:19 +0000 Subject: [PATCH 15/15] Refactor Qwen3 training template: Adjust generation block and output formatting --- trl/chat_templates/qwen3_training.jinja | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/trl/chat_templates/qwen3_training.jinja b/trl/chat_templates/qwen3_training.jinja index 3287c3d2ce3..decc48dc6b6 100644 --- a/trl/chat_templates/qwen3_training.jinja +++ b/trl/chat_templates/qwen3_training.jinja @@ -49,8 +49,9 @@ {%- set content = content.split('</think>')[-1].lstrip('\n') %} {%- endif %} {%- endif %} - {% generation %} - {{- '<|im_start|>' + message.role + '\n<think>\n' + reasoning_content.strip('\n') + '\n</think>\n\n' + content.lstrip('\n') }} + {{- '<|im_start|>' + message.role + '\n' }} + {%- generation %} + {{- '<think>\n' + reasoning_content.strip('\n') + '\n</think>\n\n' + content.lstrip('\n') }} {%- if message.tool_calls %} {%- for tool_call in message.tool_calls %} {%- if (loop.first and content) or (not loop.first) %}