Skip to content
Merged
Show file tree
Hide file tree
Changes from 19 commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
4b3aa51
Narrow prefix-preserving check to the actual requirement
qgallouedec Apr 5, 2026
0894910
Merge branch 'main' into narrow-prefix-preserving-check
qgallouedec Apr 5, 2026
730070b
Update chat template examples to use multiplication function calls
qgallouedec Apr 5, 2026
4622d77
style
qgallouedec Apr 5, 2026
08d4c51
Move chat templates from inline strings to `.jinja` files
qgallouedec Apr 5, 2026
276559d
tools in dummy
qgallouedec Apr 5, 2026
673c35d
Add chat template files to MANIFEST.in
qgallouedec Apr 5, 2026
604c476
Enhance chat template handling to include tool call formatting in mes…
qgallouedec Apr 5, 2026
83a7ef6
align grpo and async
qgallouedec Apr 5, 2026
0f28384
Merge branch 'main' into chat-templates-files
qgallouedec Apr 6, 2026
e5d7cdf
revert no content
qgallouedec Apr 6, 2026
a618809
docstyle ignore
qgallouedec Apr 6, 2026
a0b81b1
Merge branch 'main' into chat-templates-files
qgallouedec Apr 6, 2026
67ab0af
Merge branch 'main' into chat-templates-files
qgallouedec Apr 6, 2026
63ec7d3
Merge branch 'main' into chat-templates-files
qgallouedec Apr 7, 2026
c838146
Merge branch 'main' into chat-templates-files
qgallouedec Apr 7, 2026
7b7f5d1
revert old modif
qgallouedec Apr 7, 2026
e37fe00
Merge branch 'main' into chat-templates-files
qgallouedec Apr 7, 2026
3c2fc8e
Add `{% generation %}` support to training chat templates
qgallouedec Apr 7, 2026
d887724
conditionally call this in grpo
qgallouedec Apr 7, 2026
9d4d57f
align
qgallouedec Apr 7, 2026
632e8b1
Merge branch 'main' into generation-tags
qgallouedec Apr 8, 2026
b33bc66
Refactor Qwen3 training template: Adjust generation block and output …
qgallouedec Apr 8, 2026
8820a29
Merge branch 'main' into generation-tags
qgallouedec Apr 9, 2026
31e640f
Merge branch 'main' into generation-tags
qgallouedec Apr 9, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions MANIFEST.in
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ include CONTRIBUTING.md
include README.md
include trl/accelerate_configs/*.yaml
include trl/templates/*.md
include trl/chat_templates/*.jinja
include trl/chat_templates/*.md
include trl/skills/**/*.md
recursive-exclude * __pycache__
prune tests
2 changes: 1 addition & 1 deletion docs/source/sft_trainer.md
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ training_args = SFTConfig(assistant_only_loss=True)
![train_on_assistant](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/train_on_assistant.png)

> [!WARNING]
> This functionality is only available for chat templates that support returning the assistant tokens mask via the `{% generation %}` and `{% endgeneration %}` keywords. For an example of such a template, see [HuggingFaceTB/SmolLM3-3B](https://huggingface.co/HuggingFaceTB/SmolLM3-3B/blob/main/chat_template.jinja#L76-L82).
> This functionality requires the chat template to include `{% generation %}` and `{% endgeneration %}` keywords. For known model families (e.g. Qwen3), TRL automatically patches the template when `assistant_only_loss=True`. For other models, check that your chat template includes these keywords — see [HuggingFaceTB/SmolLM3-3B](https://huggingface.co/HuggingFaceTB/SmolLM3-3B/blob/main/chat_template.jinja#L76-L82) for an example.

### Train on completion only

Expand Down
35 changes: 34 additions & 1 deletion tests/test_chat_template_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,6 @@ def test_non_prefix_preserving_template(self):
class TestGetTrainingChatTemplate:
def test_new_chat_template_is_prefix_preserving(self, tokenizer_name):
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
assert is_chat_template_prefix_preserving(tokenizer) is False
Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For future models, `get_training_chat_template` will be called with any chat template, not only non-prefix-preserving ones, so the test should no longer assert that the original template is non-prefix-preserving.

tokenizer.chat_template = get_training_chat_template(tokenizer)
assert is_chat_template_prefix_preserving(tokenizer) is True

Expand Down Expand Up @@ -392,6 +391,40 @@ def test_behavior_unchanged_generation_prompt_with_enable_thinking_false(self, t
)
assert before == after

def test_assistant_masks(self, tokenizer_name):
    """Single exchange: the training template flags the assistant turn but not the leading user turn."""
    conversation = [
        {"role": "user", "content": "What color is the sky?"},
        {"role": "assistant", "content": "It is blue."},
    ]
    tok = AutoTokenizer.from_pretrained(tokenizer_name)
    training_template = get_training_chat_template(tok)
    encoded = tok.apply_chat_template(
        conversation,
        chat_template=training_template,
        return_assistant_tokens_mask=True,
        return_dict=True,
    )
    assistant_mask = encoded["assistant_masks"]
    # At least one token must be flagged (mask value 1) as assistant output.
    assert 1 in assistant_mask
    # The sequence opens with the user turn, which must stay unflagged.
    assert assistant_mask[0] == 0
    # The sequence closes with the assistant turn (ending with <|im_end|>), which must be flagged.
    assert assistant_mask[-1] == 1

def test_assistant_masks_multi_turn(self, tokenizer_name):
    """Two user/assistant exchanges must yield exactly two flagged (mask value 1) regions."""
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
    messages = [
        {"role": "user", "content": "Hi"},
        {"role": "assistant", "content": "Hello!"},
        {"role": "user", "content": "Bye"},
        {"role": "assistant", "content": "Goodbye!"},
    ]
    chat_template = get_training_chat_template(tokenizer)
    result = tokenizer.apply_chat_template(
        messages, chat_template=chat_template, return_assistant_tokens_mask=True, return_dict=True
    )
    masks = result["assistant_masks"]
    # The conversation opens with a user turn, so the mask must start at 0. Without this check,
    # the transition count below would also accept the inverted pattern 1→0, 0→1, 1→0.
    assert masks[0] == 0
    # Two assistant turns give three flips: 0→1, 1→0, 0→1. Starting at 0 with an odd number of
    # flips also guarantees the sequence ends inside the final assistant turn.
    transitions = sum(prev != cur for prev, cur in zip(masks, masks[1:]))
    assert transitions == 3


@pytest.mark.parametrize(
"tokenizer_name",
Expand Down
506 changes: 13 additions & 493 deletions trl/chat_template_utils.py

Large diffs are not rendered by default.

54 changes: 54 additions & 0 deletions trl/chat_templates/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
# Chat Templates

Jinja2 chat templates stored here serve two purposes:

1. **Identity comparison**: detecting which model is being used (by comparing `tokenizer.chat_template` against known templates) to add the appropriate response schema (`add_response_schema`) or swap in a training template (`get_training_chat_template`).
2. **Training patches**: modified templates that fix training-specific issues (prefix-preservation for GRPO, `{% generation %}` markers for SFT assistant-only loss).

**Why prefix-preserving?** The GRPO tool call loop extracts tool response formatting tokens by comparing tokenizations with and without tool messages appended (`_get_tool_suffix_ids`). This requires the chat template to be *prefix-preserving*: appending messages must not change how earlier messages are rendered.

**Why generation-tagged?** SFT with `assistant_only_loss=True` requires the chat template to include `{% generation %}` / `{% endgeneration %}` markers around assistant output, so `return_assistant_tokens_mask=True` can produce correct masks. Most model templates don't include these markers natively.

## Original templates

Used for identity comparison only.

### `qwen3.jinja`

Original Qwen3 chat template.

### `qwen3_5_2b_and_below.jinja` / `qwen3_5_4b_and_above.jinja`

Original Qwen3.5 chat templates.

## Training templates

Patched templates that fix training-specific issues. Swapped in at init when tools are enabled (GRPO) or when `assistant_only_loss=True` (SFT).

### `qwen3_training.jinja`

Patched Qwen3 template. Diff vs `qwen3.jinja`:

Require both `<think>` and `</think>` to be present before parsing, to avoid incorrect splitting when the model generates only one tag:

```diff
- {%- if '</think>' in content %}
+ {%- if '<think>' in content and '</think>' in content %}
```

Always include the thinking block regardless of message position. The original conditionally omits it based on `loop.last`, which changes the assistant rendering when a tool message is appended — breaking prefix-preservation:

```diff
- {%- if loop.index0 > ns.last_query_index %}
- {%- if loop.last or (not loop.last and reasoning_content) %}
- {{- '<|im_start|>' + message.role + '\n<think>\n' + reasoning_content.strip('\n') + '\n</think>\n\n' + content.lstrip('\n') }}
- {%- else %}
- {{- '<|im_start|>' + message.role + '\n' + content }}
- {%- endif %}
- {%- else %}
- {{- '<|im_start|>' + message.role + '\n' + content }}
- {%- endif %}
+ {{- '<|im_start|>' + message.role + '\n<think>\n' + reasoning_content.strip('\n') + '\n</think>\n\n' + content.lstrip('\n') }}
```

Wrap assistant message output with `{% generation %}` / `{% endgeneration %}` so that `return_assistant_tokens_mask=True` produces correct masks for SFT assistant-only loss.
89 changes: 89 additions & 0 deletions trl/chat_templates/qwen3.jinja
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
{%- if tools %}
{{- '<|im_start|>system\n' }}
{%- if messages[0].role == 'system' %}
{{- messages[0].content + '\n\n' }}
{%- endif %}
{{- "# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>" }}
{%- for tool in tools %}
{{- "\n" }}
{{- tool | tojson }}
{%- endfor %}
{{- "\n</tools>\n\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{\"name\": <function-name>, \"arguments\": <args-json-object>}\n</tool_call><|im_end|>\n" }}
{%- else %}
{%- if messages[0].role == 'system' %}
{{- '<|im_start|>system\n' + messages[0].content + '<|im_end|>\n' }}
{%- endif %}
{%- endif %}
{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}
{%- for message in messages[::-1] %}
{%- set index = (messages|length - 1) - loop.index0 %}
{%- if ns.multi_step_tool and message.role == "user" and message.content is string and not(message.content.startswith('<tool_response>') and message.content.endswith('</tool_response>')) %}
{%- set ns.multi_step_tool = false %}
{%- set ns.last_query_index = index %}
{%- endif %}
{%- endfor %}
{%- for message in messages %}
{%- if message.content is string %}
{%- set content = message.content %}
{%- else %}
{%- set content = '' %}
{%- endif %}
{%- if (message.role == "user") or (message.role == "system" and not loop.first) %}
{{- '<|im_start|>' + message.role + '\n' + content + '<|im_end|>' + '\n' }}
{%- elif message.role == "assistant" %}
{%- set reasoning_content = '' %}
{%- if message.reasoning_content is string %}
{%- set reasoning_content = message.reasoning_content %}
{%- else %}
{%- if '</think>' in content %}
{%- set reasoning_content = content.split('</think>')[0].rstrip('\n').split('<think>')[-1].lstrip('\n') %}
{%- set content = content.split('</think>')[-1].lstrip('\n') %}
{%- endif %}
{%- endif %}
{%- if loop.index0 > ns.last_query_index %}
{%- if loop.last or (not loop.last and reasoning_content) %}
{{- '<|im_start|>' + message.role + '\n<think>\n' + reasoning_content.strip('\n') + '\n</think>\n\n' + content.lstrip('\n') }}
{%- else %}
{{- '<|im_start|>' + message.role + '\n' + content }}
{%- endif %}
{%- else %}
{{- '<|im_start|>' + message.role + '\n' + content }}
{%- endif %}
{%- if message.tool_calls %}
{%- for tool_call in message.tool_calls %}
{%- if (loop.first and content) or (not loop.first) %}
{{- '\n' }}
{%- endif %}
{%- if tool_call.function %}
{%- set tool_call = tool_call.function %}
{%- endif %}
{{- '<tool_call>\n{"name": "' }}
{{- tool_call.name }}
{{- '", "arguments": ' }}
{%- if tool_call.arguments is string %}
{{- tool_call.arguments }}
{%- else %}
{{- tool_call.arguments | tojson }}
{%- endif %}
{{- '}\n</tool_call>' }}
{%- endfor %}
{%- endif %}
{{- '<|im_end|>\n' }}
{%- elif message.role == "tool" %}
{%- if loop.first or (messages[loop.index0 - 1].role != "tool") %}
{{- '<|im_start|>user' }}
{%- endif %}
{{- '\n<tool_response>\n' }}
{{- content }}
{{- '\n</tool_response>' }}
{%- if loop.last or (messages[loop.index0 + 1].role != "tool") %}
{{- '<|im_end|>\n' }}
{%- endif %}
{%- endif %}
{%- endfor %}
{%- if add_generation_prompt %}
{{- '<|im_start|>assistant\n' }}
{%- if enable_thinking is defined and enable_thinking is false %}
{{- '<think>\n\n</think>\n\n' }}
{%- endif %}
{%- endif %}
154 changes: 154 additions & 0 deletions trl/chat_templates/qwen3_5_2b_and_below.jinja
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
{%- set image_count = namespace(value=0) %}
{%- set video_count = namespace(value=0) %}
{%- macro render_content(content, do_vision_count, is_system_content=false) %}
{%- if content is string %}
{{- content }}
{%- elif content is iterable and content is not mapping %}
{%- for item in content %}
{%- if 'image' in item or 'image_url' in item or item.type == 'image' %}
{%- if is_system_content %}
{{- raise_exception('System message cannot contain images.') }}
{%- endif %}
{%- if do_vision_count %}
{%- set image_count.value = image_count.value + 1 %}
{%- endif %}
{%- if add_vision_id %}
{{- 'Picture ' ~ image_count.value ~ ': ' }}
{%- endif %}
{{- '<|vision_start|><|image_pad|><|vision_end|>' }}
{%- elif 'video' in item or item.type == 'video' %}
{%- if is_system_content %}
{{- raise_exception('System message cannot contain videos.') }}
{%- endif %}
{%- if do_vision_count %}
{%- set video_count.value = video_count.value + 1 %}
{%- endif %}
{%- if add_vision_id %}
{{- 'Video ' ~ video_count.value ~ ': ' }}
{%- endif %}
{{- '<|vision_start|><|video_pad|><|vision_end|>' }}
{%- elif 'text' in item %}
{{- item.text }}
{%- else %}
{{- raise_exception('Unexpected item type in content.') }}
{%- endif %}
{%- endfor %}
{%- elif content is none or content is undefined %}
{{- '' }}
{%- else %}
{{- raise_exception('Unexpected content type.') }}
{%- endif %}
{%- endmacro %}
{%- if not messages %}
{{- raise_exception('No messages provided.') }}
{%- endif %}
{%- if tools and tools is iterable and tools is not mapping %}
{{- '<|im_start|>system\n' }}
{{- "# Tools\n\nYou have access to the following functions:\n\n<tools>" }}
{%- for tool in tools %}
{{- "\n" }}
{{- tool | tojson }}
{%- endfor %}
{{- "\n</tools>" }}
{{- '\n\nIf you choose to call a function ONLY reply in the following format with NO suffix:\n\n<tool_call>\n<function=example_function_name>\n<parameter=example_parameter_1>\nvalue_1\n</parameter>\n<parameter=example_parameter_2>\nThis is the value for the second parameter\nthat can span\nmultiple lines\n</parameter>\n</function>\n</tool_call>\n\n<IMPORTANT>\nReminder:\n- Function calls MUST follow the specified format: an inner <function=...></function> block must be nested within <tool_call></tool_call> XML tags\n- Required parameters MUST be specified\n- You may provide optional reasoning for your function call in natural language BEFORE the function call, but NOT after\n- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls\n</IMPORTANT>' }}
{%- if messages[0].role == 'system' %}
{%- set content = render_content(messages[0].content, false, true)|trim %}
{%- if content %}
{{- '\n\n' + content }}
{%- endif %}
{%- endif %}
{{- '<|im_end|>\n' }}
{%- else %}
{%- if messages[0].role == 'system' %}
{%- set content = render_content(messages[0].content, false, true)|trim %}
{{- '<|im_start|>system\n' + content + '<|im_end|>\n' }}
{%- endif %}
{%- endif %}
{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}
{%- for message in messages[::-1] %}
{%- set index = (messages|length - 1) - loop.index0 %}
{%- if ns.multi_step_tool and message.role == "user" %}
{%- set content = render_content(message.content, false)|trim %}
{%- if not(content.startswith('<tool_response>') and content.endswith('</tool_response>')) %}
{%- set ns.multi_step_tool = false %}
{%- set ns.last_query_index = index %}
{%- endif %}
{%- endif %}
{%- endfor %}
{%- if ns.multi_step_tool %}
{{- raise_exception('No user query found in messages.') }}
{%- endif %}
{%- for message in messages %}
{%- set content = render_content(message.content, true)|trim %}
{%- if message.role == "system" %}
{%- if not loop.first %}
{{- raise_exception('System message must be at the beginning.') }}
{%- endif %}
{%- elif message.role == "user" %}
{{- '<|im_start|>' + message.role + '\n' + content + '<|im_end|>' + '\n' }}
{%- elif message.role == "assistant" %}
{%- set reasoning_content = '' %}
{%- if message.reasoning_content is string %}
{%- set reasoning_content = message.reasoning_content %}
{%- else %}
{%- if '</think>' in content %}
{%- set reasoning_content = content.split('</think>')[0].rstrip('\n').split('<think>')[-1].lstrip('\n') %}
{%- set content = content.split('</think>')[-1].lstrip('\n') %}
{%- endif %}
{%- endif %}
{%- set reasoning_content = reasoning_content|trim %}
{%- if loop.index0 > ns.last_query_index %}
{{- '<|im_start|>' + message.role + '\n<think>\n' + reasoning_content + '\n</think>\n\n' + content }}
{%- else %}
{{- '<|im_start|>' + message.role + '\n' + content }}
{%- endif %}
{%- if message.tool_calls and message.tool_calls is iterable and message.tool_calls is not mapping %}
{%- for tool_call in message.tool_calls %}
{%- if tool_call.function is defined %}
{%- set tool_call = tool_call.function %}
{%- endif %}
{%- if loop.first %}
{%- if content|trim %}
{{- '\n\n<tool_call>\n<function=' + tool_call.name + '>\n' }}
{%- else %}
{{- '<tool_call>\n<function=' + tool_call.name + '>\n' }}
{%- endif %}
{%- else %}
{{- '\n<tool_call>\n<function=' + tool_call.name + '>\n' }}
{%- endif %}
{%- if tool_call.arguments is defined %}
{%- for args_name, args_value in tool_call.arguments|items %}
{{- '<parameter=' + args_name + '>\n' }}
{%- set args_value = args_value | tojson | safe if args_value is mapping or (args_value is sequence and args_value is not string) else args_value | string %}
{{- args_value }}
{{- '\n</parameter>\n' }}
{%- endfor %}
{%- endif %}
{{- '</function>\n</tool_call>' }}
{%- endfor %}
{%- endif %}
{{- '<|im_end|>\n' }}
{%- elif message.role == "tool" %}
{%- if loop.previtem and loop.previtem.role != "tool" %}
{{- '<|im_start|>user' }}
{%- endif %}
{{- '\n<tool_response>\n' }}
{{- content }}
{{- '\n</tool_response>' }}
{%- if not loop.last and loop.nextitem.role != "tool" %}
{{- '<|im_end|>\n' }}
{%- elif loop.last %}
{{- '<|im_end|>\n' }}
{%- endif %}
{%- else %}
{{- raise_exception('Unexpected message role.') }}
{%- endif %}
{%- endfor %}
{%- if add_generation_prompt %}
{{- '<|im_start|>assistant\n' }}
{%- if enable_thinking is defined and enable_thinking is true %}
{{- '<think>\n' }}
{%- else %}
{{- '<think>\n\n</think>\n\n' }}
{%- endif %}
{%- endif %}
Loading
Loading