Skip to content
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions MANIFEST.in
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ include CONTRIBUTING.md
include README.md
include trl/accelerate_configs/*.yaml
include trl/templates/*.md
include trl/chat_templates/*.jinja
include trl/chat_templates/*.md
include trl/skills/**/*.md
recursive-exclude * __pycache__
prune tests
39 changes: 27 additions & 12 deletions tests/test_chat_template_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,11 +148,23 @@ def test_prefix_preserving_template(self):
tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen3MoeForSequenceClassification")
tokenizer.chat_template = textwrap.dedent(r"""
{%- for message in messages %}
{%- set content = message.content if message.content is string else '' %}

{%- if message.role == 'user' %}
{{- '<|im_start|>user\n' + message.content + '<|im_end|>\n' }}
{{- '<|im_start|>user\n' + content + '<|im_end|>\n' }}
{%- elif message.role == 'assistant' %}
{{- '<|im_start|>assistant\n' + message.content + '<|im_end|>\n' }}
{{- '<|im_start|>assistant\n' + content }}
{%- if message.tool_calls %}
{%- for tool_call in message.tool_calls %}
{%- if tool_call.function %}
{%- set tool_call = tool_call.function %}
{%- endif %}
{{- '<tool_call>' + tool_call.name + '</tool_call>' }}
{%- endfor %}
{%- endif %}
{{- '<|im_end|>\n' }}
{%- elif message.role == 'tool' %}
{{- '<|im_start|>tool\n' + content + '<|im_end|>\n' }}
{%- endif %}

{%- endfor %}
Expand All @@ -164,8 +176,9 @@ def test_prefix_preserving_template(self):

def test_non_prefix_preserving_template(self):
tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen3MoeForSequenceClassification")
# The following template is quite typical of models like Qwen3 and GPT-OSS, where the thinking part is
# only present for last assistant message, which makes it non-prefix-preserving.
# The following template is quite typical of models like Qwen3 and GPT-OSS, where the thinking part (even
# empty) is only present for last assistant message, which makes it non-prefix-preserving: appending a tool
# message changes the earlier output.
# docstyle-ignore
tokenizer.chat_template = textwrap.dedent(r"""
{%- if messages[0].role == 'system' %}
Expand Down Expand Up @@ -202,7 +215,17 @@ def test_non_prefix_preserving_template(self):
{%- else %}
{{- '<|im_start|>' + message.role + '\n' + content }}
{%- endif %}
{%- if message.tool_calls %}
{%- for tool_call in message.tool_calls %}
{%- if tool_call.function %}
{%- set tool_call = tool_call.function %}
{%- endif %}
{{- '<tool_call>' + tool_call.name + '</tool_call>' }}
{%- endfor %}
{%- endif %}
{{- '<|im_end|>\n' }}
{%- elif message.role == "tool" %}
{{- '<|im_start|>tool\n' + content + '<|im_end|>\n' }}
{%- endif %}
{%- endfor %}
{%- if add_generation_prompt %}
Expand All @@ -218,14 +241,6 @@ def test_non_prefix_preserving_template(self):
"tokenizer_name",
[
pytest.param("trl-internal-testing/tiny-Qwen3MoeForSequenceClassification", id="qwen3"),
pytest.param(
"trl-internal-testing/tiny-Qwen3_5ForConditionalGeneration",
id="qwen35",
marks=pytest.mark.skipif(
Version(transformers.__version__) < Version("5.0.0"),
reason="Qwen3.5 tokenizer requires transformers>=5.0.0",
),
),
],
)
class TestGetTrainingChatTemplate:
Expand Down
Loading
Loading