Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions MANIFEST.in
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ include CONTRIBUTING.md
include README.md
include trl/accelerate_configs/*.yaml
include trl/templates/*.md
include trl/chat_templates/*.jinja
include trl/chat_templates/*.md
include trl/skills/**/*.md
recursive-exclude * __pycache__
prune tests
17 changes: 7 additions & 10 deletions tests/test_chat_template_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,8 @@ def test_prefix_preserving_template(self):
{{- '<|im_start|>user\n' + message.content + '<|im_end|>\n' }}
{%- elif message.role == 'assistant' %}
{{- '<|im_start|>assistant\n' + message.content + '<|im_end|>\n' }}
{%- elif message.role == 'tool' %}
{{- '<|im_start|>tool\n' + message.content + '<|im_end|>\n' }}
{%- endif %}

{%- endfor %}
Expand All @@ -164,8 +166,9 @@ def test_prefix_preserving_template(self):

def test_non_prefix_preserving_template(self):
tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen3MoeForSequenceClassification")
# The following template is quite typical of models like Qwen3 and GPT-OSS, where the thinking part is
# only present for last assistant message, which makes it non-prefix-preserving.
# The following template is quite typical of models like Qwen3 and GPT-OSS, where the thinking part (even
# empty) is only present for last assistant message, which makes it non-prefix-preserving: appending a tool
# message changes the earlier output.
# docstyle-ignore
tokenizer.chat_template = textwrap.dedent(r"""
{%- if messages[0].role == 'system' %}
Expand Down Expand Up @@ -203,6 +206,8 @@ def test_non_prefix_preserving_template(self):
{{- '<|im_start|>' + message.role + '\n' + content }}
{%- endif %}
{{- '<|im_end|>\n' }}
{%- elif message.role == "tool" %}
{{- '<|im_start|>tool\n' + content + '<|im_end|>\n' }}
{%- endif %}
{%- endfor %}
{%- if add_generation_prompt %}
Expand All @@ -218,14 +223,6 @@ def test_non_prefix_preserving_template(self):
"tokenizer_name",
[
pytest.param("trl-internal-testing/tiny-Qwen3MoeForSequenceClassification", id="qwen3"),
pytest.param(
"trl-internal-testing/tiny-Qwen3_5ForConditionalGeneration",
id="qwen35",
marks=pytest.mark.skipif(
Version(transformers.__version__) < Version("5.0.0"),
reason="Qwen3.5 tokenizer requires transformers>=5.0.0",
),
),
],
)
class TestGetTrainingChatTemplate:
Expand Down
590 changes: 45 additions & 545 deletions trl/chat_template_utils.py

Large diffs are not rendered by default.

50 changes: 50 additions & 0 deletions trl/chat_templates/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
# Chat Templates

Jinja2 chat templates stored here serve two purposes:

1. **Identity comparison**: detecting which model is being used (by comparing `tokenizer.chat_template` against known templates) to add the appropriate response schema (`add_response_schema`) or swap in a training template (`get_training_chat_template`).
2. **Training patches**: modified templates that fix prefix-preservation issues for the GRPO tool call loop.

**Why prefix-preserving?** The GRPO tool call loop extracts tool response formatting tokens by comparing tokenizations with and without tool messages appended (`_get_tool_suffix_ids`). This requires the chat template to be *prefix-preserving*: appending messages must not change how earlier messages are rendered.

## Original templates

Used for identity comparison only.

### `qwen3.jinja`

Original Qwen3 chat template.

### `qwen3_5_2b_and_below.jinja` / `qwen3_5_4b_and_above.jinja`

Original Qwen3.5 chat templates.

## Training templates

Patched templates that fix prefix-preservation issues. Swapped in at init when tools are enabled.

### `qwen3_training.jinja`

Patched Qwen3 template that always includes thinking blocks, making it prefix-preserving. Diff vs `qwen3.jinja`:

Require both `<think>` and `</think>` to be present before parsing, to avoid incorrect splitting when the model generates only one tag:

```diff
- {%- if '</think>' in content %}
+ {%- if '<think>' in content and '</think>' in content %}
```

Always include the thinking block regardless of message position. The original conditionally omits it based on `loop.last`, which changes the assistant rendering when a tool message is appended — breaking prefix-preservation:

```diff
- {%- if loop.index0 > ns.last_query_index %}
- {%- if loop.last or (not loop.last and reasoning_content) %}
- {{- '<|im_start|>' + message.role + '\n<think>\n' + reasoning_content.strip('\n') + '\n</think>\n\n' + content.lstrip('\n') }}
- {%- else %}
- {{- '<|im_start|>' + message.role + '\n' + content }}
- {%- endif %}
- {%- else %}
- {{- '<|im_start|>' + message.role + '\n' + content }}
- {%- endif %}
+ {{- '<|im_start|>' + message.role + '\n<think>\n' + reasoning_content.strip('\n') + '\n</think>\n\n' + content.lstrip('\n') }}
```
89 changes: 89 additions & 0 deletions trl/chat_templates/qwen3.jinja
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
{%- if tools %}
{{- '<|im_start|>system\n' }}
{%- if messages[0].role == 'system' %}
{{- messages[0].content + '\n\n' }}
{%- endif %}
{{- "# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>" }}
{%- for tool in tools %}
{{- "\n" }}
{{- tool | tojson }}
{%- endfor %}
{{- "\n</tools>\n\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{\"name\": <function-name>, \"arguments\": <args-json-object>}\n</tool_call><|im_end|>\n" }}
{%- else %}
{%- if messages[0].role == 'system' %}
{{- '<|im_start|>system\n' + messages[0].content + '<|im_end|>\n' }}
{%- endif %}
{%- endif %}
{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}
{%- for message in messages[::-1] %}
{%- set index = (messages|length - 1) - loop.index0 %}
{%- if ns.multi_step_tool and message.role == "user" and message.content is string and not(message.content.startswith('<tool_response>') and message.content.endswith('</tool_response>')) %}
{%- set ns.multi_step_tool = false %}
{%- set ns.last_query_index = index %}
{%- endif %}
{%- endfor %}
{%- for message in messages %}
{%- if message.content is string %}
{%- set content = message.content %}
{%- else %}
{%- set content = '' %}
{%- endif %}
{%- if (message.role == "user") or (message.role == "system" and not loop.first) %}
{{- '<|im_start|>' + message.role + '\n' + content + '<|im_end|>' + '\n' }}
{%- elif message.role == "assistant" %}
{%- set reasoning_content = '' %}
{%- if message.reasoning_content is string %}
{%- set reasoning_content = message.reasoning_content %}
{%- else %}
{%- if '</think>' in content %}
{%- set reasoning_content = content.split('</think>')[0].rstrip('\n').split('<think>')[-1].lstrip('\n') %}
{%- set content = content.split('</think>')[-1].lstrip('\n') %}
{%- endif %}
{%- endif %}
{%- if loop.index0 > ns.last_query_index %}
{%- if loop.last or (not loop.last and reasoning_content) %}
{{- '<|im_start|>' + message.role + '\n<think>\n' + reasoning_content.strip('\n') + '\n</think>\n\n' + content.lstrip('\n') }}
{%- else %}
{{- '<|im_start|>' + message.role + '\n' + content }}
{%- endif %}
{%- else %}
{{- '<|im_start|>' + message.role + '\n' + content }}
{%- endif %}
{%- if message.tool_calls %}
{%- for tool_call in message.tool_calls %}
{%- if (loop.first and content) or (not loop.first) %}
{{- '\n' }}
{%- endif %}
{%- if tool_call.function %}
{%- set tool_call = tool_call.function %}
{%- endif %}
{{- '<tool_call>\n{"name": "' }}
{{- tool_call.name }}
{{- '", "arguments": ' }}
{%- if tool_call.arguments is string %}
{{- tool_call.arguments }}
{%- else %}
{{- tool_call.arguments | tojson }}
{%- endif %}
{{- '}\n</tool_call>' }}
{%- endfor %}
{%- endif %}
{{- '<|im_end|>\n' }}
{%- elif message.role == "tool" %}
{%- if loop.first or (messages[loop.index0 - 1].role != "tool") %}
{{- '<|im_start|>user' }}
{%- endif %}
{{- '\n<tool_response>\n' }}
{{- content }}
{{- '\n</tool_response>' }}
{%- if loop.last or (messages[loop.index0 + 1].role != "tool") %}
{{- '<|im_end|>\n' }}
{%- endif %}
{%- endif %}
{%- endfor %}
{%- if add_generation_prompt %}
{{- '<|im_start|>assistant\n' }}
{%- if enable_thinking is defined and enable_thinking is false %}
{{- '<think>\n\n</think>\n\n' }}
{%- endif %}
{%- endif %}
154 changes: 154 additions & 0 deletions trl/chat_templates/qwen3_5_2b_and_below.jinja
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
{%- set image_count = namespace(value=0) %}
{%- set video_count = namespace(value=0) %}
{%- macro render_content(content, do_vision_count, is_system_content=false) %}
{%- if content is string %}
{{- content }}
{%- elif content is iterable and content is not mapping %}
{%- for item in content %}
{%- if 'image' in item or 'image_url' in item or item.type == 'image' %}
{%- if is_system_content %}
{{- raise_exception('System message cannot contain images.') }}
{%- endif %}
{%- if do_vision_count %}
{%- set image_count.value = image_count.value + 1 %}
{%- endif %}
{%- if add_vision_id %}
{{- 'Picture ' ~ image_count.value ~ ': ' }}
{%- endif %}
{{- '<|vision_start|><|image_pad|><|vision_end|>' }}
{%- elif 'video' in item or item.type == 'video' %}
{%- if is_system_content %}
{{- raise_exception('System message cannot contain videos.') }}
{%- endif %}
{%- if do_vision_count %}
{%- set video_count.value = video_count.value + 1 %}
{%- endif %}
{%- if add_vision_id %}
{{- 'Video ' ~ video_count.value ~ ': ' }}
{%- endif %}
{{- '<|vision_start|><|video_pad|><|vision_end|>' }}
{%- elif 'text' in item %}
{{- item.text }}
{%- else %}
{{- raise_exception('Unexpected item type in content.') }}
{%- endif %}
{%- endfor %}
{%- elif content is none or content is undefined %}
{{- '' }}
{%- else %}
{{- raise_exception('Unexpected content type.') }}
{%- endif %}
{%- endmacro %}
{%- if not messages %}
{{- raise_exception('No messages provided.') }}
{%- endif %}
{%- if tools and tools is iterable and tools is not mapping %}
{{- '<|im_start|>system\n' }}
{{- "# Tools\n\nYou have access to the following functions:\n\n<tools>" }}
{%- for tool in tools %}
{{- "\n" }}
{{- tool | tojson }}
{%- endfor %}
{{- "\n</tools>" }}
{{- '\n\nIf you choose to call a function ONLY reply in the following format with NO suffix:\n\n<tool_call>\n<function=example_function_name>\n<parameter=example_parameter_1>\nvalue_1\n</parameter>\n<parameter=example_parameter_2>\nThis is the value for the second parameter\nthat can span\nmultiple lines\n</parameter>\n</function>\n</tool_call>\n\n<IMPORTANT>\nReminder:\n- Function calls MUST follow the specified format: an inner <function=...></function> block must be nested within <tool_call></tool_call> XML tags\n- Required parameters MUST be specified\n- You may provide optional reasoning for your function call in natural language BEFORE the function call, but NOT after\n- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls\n</IMPORTANT>' }}
{%- if messages[0].role == 'system' %}
{%- set content = render_content(messages[0].content, false, true)|trim %}
{%- if content %}
{{- '\n\n' + content }}
{%- endif %}
{%- endif %}
{{- '<|im_end|>\n' }}
{%- else %}
{%- if messages[0].role == 'system' %}
{%- set content = render_content(messages[0].content, false, true)|trim %}
{{- '<|im_start|>system\n' + content + '<|im_end|>\n' }}
{%- endif %}
{%- endif %}
{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}
{%- for message in messages[::-1] %}
{%- set index = (messages|length - 1) - loop.index0 %}
{%- if ns.multi_step_tool and message.role == "user" %}
{%- set content = render_content(message.content, false)|trim %}
{%- if not(content.startswith('<tool_response>') and content.endswith('</tool_response>')) %}
{%- set ns.multi_step_tool = false %}
{%- set ns.last_query_index = index %}
{%- endif %}
{%- endif %}
{%- endfor %}
{%- if ns.multi_step_tool %}
{{- raise_exception('No user query found in messages.') }}
{%- endif %}
{%- for message in messages %}
{%- set content = render_content(message.content, true)|trim %}
{%- if message.role == "system" %}
{%- if not loop.first %}
{{- raise_exception('System message must be at the beginning.') }}
{%- endif %}
{%- elif message.role == "user" %}
{{- '<|im_start|>' + message.role + '\n' + content + '<|im_end|>' + '\n' }}
{%- elif message.role == "assistant" %}
{%- set reasoning_content = '' %}
{%- if message.reasoning_content is string %}
{%- set reasoning_content = message.reasoning_content %}
{%- else %}
{%- if '</think>' in content %}
{%- set reasoning_content = content.split('</think>')[0].rstrip('\n').split('<think>')[-1].lstrip('\n') %}
{%- set content = content.split('</think>')[-1].lstrip('\n') %}
{%- endif %}
{%- endif %}
{%- set reasoning_content = reasoning_content|trim %}
{%- if loop.index0 > ns.last_query_index %}
{{- '<|im_start|>' + message.role + '\n<think>\n' + reasoning_content + '\n</think>\n\n' + content }}
{%- else %}
{{- '<|im_start|>' + message.role + '\n' + content }}
{%- endif %}
{%- if message.tool_calls and message.tool_calls is iterable and message.tool_calls is not mapping %}
{%- for tool_call in message.tool_calls %}
{%- if tool_call.function is defined %}
{%- set tool_call = tool_call.function %}
{%- endif %}
{%- if loop.first %}
{%- if content|trim %}
{{- '\n\n<tool_call>\n<function=' + tool_call.name + '>\n' }}
{%- else %}
{{- '<tool_call>\n<function=' + tool_call.name + '>\n' }}
{%- endif %}
{%- else %}
{{- '\n<tool_call>\n<function=' + tool_call.name + '>\n' }}
{%- endif %}
{%- if tool_call.arguments is defined %}
{%- for args_name, args_value in tool_call.arguments|items %}
{{- '<parameter=' + args_name + '>\n' }}
{%- set args_value = args_value | tojson | safe if args_value is mapping or (args_value is sequence and args_value is not string) else args_value | string %}
{{- args_value }}
{{- '\n</parameter>\n' }}
{%- endfor %}
{%- endif %}
{{- '</function>\n</tool_call>' }}
{%- endfor %}
{%- endif %}
{{- '<|im_end|>\n' }}
{%- elif message.role == "tool" %}
{%- if loop.previtem and loop.previtem.role != "tool" %}
{{- '<|im_start|>user' }}
{%- endif %}
{{- '\n<tool_response>\n' }}
{{- content }}
{{- '\n</tool_response>' }}
{%- if not loop.last and loop.nextitem.role != "tool" %}
{{- '<|im_end|>\n' }}
{%- elif loop.last %}
{{- '<|im_end|>\n' }}
{%- endif %}
{%- else %}
{{- raise_exception('Unexpected message role.') }}
{%- endif %}
{%- endfor %}
{%- if add_generation_prompt %}
{{- '<|im_start|>assistant\n' }}
{%- if enable_thinking is defined and enable_thinking is true %}
{{- '<think>\n' }}
{%- else %}
{{- '<think>\n\n</think>\n\n' }}
{%- endif %}
{%- endif %}
Loading
Loading