Skip to content
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
4b3aa51
Narrow prefix-preserving check to the actual requirement
qgallouedec Apr 5, 2026
0894910
Merge branch 'main' into narrow-prefix-preserving-check
qgallouedec Apr 5, 2026
730070b
Update chat template examples to use multiplication function calls
qgallouedec Apr 5, 2026
4622d77
style
qgallouedec Apr 5, 2026
08d4c51
Move chat templates from inline strings to `.jinja` files
qgallouedec Apr 5, 2026
276559d
tools in dummy
qgallouedec Apr 5, 2026
673c35d
Add chat template files to MANIFEST.in
qgallouedec Apr 5, 2026
604c476
Enhance chat template handling to include tool call formatting in mes…
qgallouedec Apr 5, 2026
83a7ef6
align grpo and async
qgallouedec Apr 5, 2026
0f28384
Merge branch 'main' into chat-templates-files
qgallouedec Apr 6, 2026
e5d7cdf
revert no content
qgallouedec Apr 6, 2026
a618809
docstyle ignore
qgallouedec Apr 6, 2026
a0b81b1
Merge branch 'main' into chat-templates-files
qgallouedec Apr 6, 2026
67ab0af
Merge branch 'main' into chat-templates-files
qgallouedec Apr 6, 2026
b18e39e
Add GPT-OSS tool calling support
qgallouedec Apr 6, 2026
71ce5a0
fix gpt oss
qgallouedec Apr 6, 2026
8f1ad1e
Update tool suffix ID retrieval to use actual tool names for GPT-OSS …
qgallouedec Apr 6, 2026
9b9771d
style
qgallouedec Apr 6, 2026
b3f4481
align async
qgallouedec Apr 6, 2026
76a0f66
Merge branch 'main' into gpt-oss-tool-calling
qgallouedec Apr 7, 2026
0890038
Merge branch 'main' into gpt-oss-tool-calling
qgallouedec Apr 7, 2026
3253602
style
qgallouedec Apr 7, 2026
b95dbec
Merge branch 'main' into gpt-oss-tool-calling
qgallouedec Apr 7, 2026
ec81a1e
Merge branch 'main' into gpt-oss-tool-calling
qgallouedec Apr 8, 2026
450b9ef
Merge branch 'main' into gpt-oss-tool-calling
qgallouedec Apr 9, 2026
392dece
Apply suggestions from code review
qgallouedec Apr 9, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions MANIFEST.in
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ include CONTRIBUTING.md
include README.md
include trl/accelerate_configs/*.yaml
include trl/templates/*.md
include trl/chat_templates/*.jinja
include trl/chat_templates/*.md
include trl/skills/**/*.md
recursive-exclude * __pycache__
prune tests
1 change: 1 addition & 0 deletions docs/source/grpo_trainer.md
Original file line number Diff line number Diff line change
Expand Up @@ -741,6 +741,7 @@ The returned images are automatically injected into the conversation and passed
Tested with:

- [**Gemma4**](https://huggingface.co/collections/google/gemma-4) — e.g., `google/gemma-4-E2B-it`
- [**GPT-OSS**](https://huggingface.co/collections/openai/gpt-oss) — e.g., `openai/gpt-oss-20b`
- [**Qwen3**](https://huggingface.co/collections/Qwen/qwen3) — e.g., `Qwen/Qwen3-0.6B`
- [**Qwen3.5**](https://huggingface.co/collections/Qwen/qwen35) — e.g., `Qwen/Qwen3.5-2B`

Expand Down
50 changes: 36 additions & 14 deletions tests/test_chat_template_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ def test_clone_with_sequence_classification_model(self):
@pytest.mark.parametrize(
"tokenizer_name",
[
pytest.param("trl-internal-testing/tiny-GptOssForCausalLM", id="gptoss"),
pytest.param("trl-internal-testing/tiny-Qwen3MoeForSequenceClassification", id="qwen3"),
pytest.param("trl-internal-testing/tiny-Qwen3_5ForConditionalGeneration", id="qwen35"),
],
Expand Down Expand Up @@ -146,13 +147,25 @@ def test_add_response_schema(self, tokenizer_name):
class TestIsChatTemplatePrefixPreserving:
def test_prefix_preserving_template(self):
tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen3MoeForSequenceClassification")
# docstyle-ignore
tokenizer.chat_template = textwrap.dedent(r"""
{%- for message in messages %}

{%- if message.role == 'user' %}
{{- '<|im_start|>user\n' + message.content + '<|im_end|>\n' }}
{%- elif message.role == 'assistant' %}
{{- '<|im_start|>assistant\n' + message.content + '<|im_end|>\n' }}
{{- '<|im_start|>assistant\n' + message.content }}
{%- if message.tool_calls %}
{%- for tool_call in message.tool_calls %}
{%- if tool_call.function %}
{%- set tool_call = tool_call.function %}
{%- endif %}
{{- '<tool_call>' + tool_call.name + '</tool_call>' }}
{%- endfor %}
{%- endif %}
{{- '<|im_end|>\n' }}
{%- elif message.role == 'tool' %}
{{- '<|im_start|>tool\n' + message.content + '<|im_end|>\n' }}
{%- endif %}

{%- endfor %}
Expand All @@ -164,8 +177,9 @@ def test_prefix_preserving_template(self):

def test_non_prefix_preserving_template(self):
tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen3MoeForSequenceClassification")
# The following template is quite typical of models like Qwen3 and GPT-OSS, where the thinking part is
# only present for last assistant message, which makes it non-prefix-preserving.
# The following template is quite typical of models like Qwen3 and GPT-OSS, where the thinking part (even
# empty) is only present for the last assistant message, which makes it non-prefix-preserving: appending a
# tool message changes the earlier output.
# docstyle-ignore
tokenizer.chat_template = textwrap.dedent(r"""
{%- if messages[0].role == 'system' %}
Expand Down Expand Up @@ -202,7 +216,17 @@ def test_non_prefix_preserving_template(self):
{%- else %}
{{- '<|im_start|>' + message.role + '\n' + content }}
{%- endif %}
{%- if message.tool_calls %}
{%- for tool_call in message.tool_calls %}
{%- if tool_call.function %}
{%- set tool_call = tool_call.function %}
{%- endif %}
{{- '<tool_call>' + tool_call.name + '</tool_call>' }}
{%- endfor %}
{%- endif %}
{{- '<|im_end|>\n' }}
{%- elif message.role == "tool" %}
{{- '<|im_start|>tool\n' + content + '<|im_end|>\n' }}
{%- endif %}
{%- endfor %}
{%- if add_generation_prompt %}
Expand All @@ -218,14 +242,6 @@ def test_non_prefix_preserving_template(self):
"tokenizer_name",
[
pytest.param("trl-internal-testing/tiny-Qwen3MoeForSequenceClassification", id="qwen3"),
pytest.param(
"trl-internal-testing/tiny-Qwen3_5ForConditionalGeneration",
id="qwen35",
marks=pytest.mark.skipif(
Version(transformers.__version__) < Version("5.0.0"),
reason="Qwen3.5 tokenizer requires transformers>=5.0.0",
),
),
],
)
class TestGetTrainingChatTemplate:
Expand Down Expand Up @@ -381,6 +397,7 @@ def test_behavior_unchanged_generation_prompt_with_enable_thinking_false(self, t
@pytest.mark.parametrize(
"tokenizer_name",
[
pytest.param("trl-internal-testing/tiny-GptOssForCausalLM", id="gptoss"),
pytest.param("trl-internal-testing/tiny-Qwen3MoeForSequenceClassification", id="qwen3"),
pytest.param("trl-internal-testing/tiny-Qwen3_5ForConditionalGeneration", id="qwen35"),
pytest.param(
Expand Down Expand Up @@ -415,8 +432,11 @@ def test_parse_response(self, tokenizer_name):
assert parsed == messages[-1]

def test_parse_response_with_reasoning_content(self, tokenizer_name):
if tokenizer_name == "trl-internal-testing/tiny-Gemma4ForConditionalGeneration":
pytest.skip("Gemma4 doesn't support inline reasoning_content.")
if tokenizer_name in (
"trl-internal-testing/tiny-Gemma4ForConditionalGeneration",
"trl-internal-testing/tiny-GptOssForCausalLM",
):
pytest.skip("This model doesn't support inline reasoning_content.")
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
if getattr(tokenizer, "response_schema", None) is None:
tokenizer = add_response_schema(tokenizer)
Expand Down Expand Up @@ -475,7 +495,7 @@ def test_parse_response_tool_call_without_arguments(self, tokenizer_name):
tool_calls = [{"type": "function", "function": {"name": "ping", "arguments": {}}}]
messages = [
{"role": "user", "content": "Ping the service."},
{"role": "assistant", "tool_calls": tool_calls},
{"role": "assistant", "content": "", "tool_calls": tool_calls},
]
prefix = tokenizer.apply_chat_template(messages[:1], add_generation_prompt=True).input_ids
text = tokenizer.apply_chat_template(messages).input_ids
Expand All @@ -484,6 +504,8 @@ def test_parse_response_tool_call_without_arguments(self, tokenizer_name):
assert parsed == {"role": "assistant", "content": "", "tool_calls": tool_calls}

def test_parse_response_multiple_tool_calls(self, tokenizer_name):
if tokenizer_name == "trl-internal-testing/tiny-GptOssForCausalLM":
pytest.skip("GPT-OSS template only renders one tool call per assistant message.")
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
if getattr(tokenizer, "response_schema", None) is None:
tokenizer = add_response_schema(tokenizer)
Expand Down
Loading
Loading