Skip to content
Merged
Show file tree
Hide file tree
Changes from 22 commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
4b3aa51
Narrow prefix-preserving check to the actual requirement
qgallouedec Apr 5, 2026
0894910
Merge branch 'main' into narrow-prefix-preserving-check
qgallouedec Apr 5, 2026
730070b
Update chat template examples to use multiplication function calls
qgallouedec Apr 5, 2026
4622d77
style
qgallouedec Apr 5, 2026
08d4c51
Move chat templates from inline strings to `.jinja` files
qgallouedec Apr 5, 2026
276559d
tools in dummy
qgallouedec Apr 5, 2026
673c35d
Add chat template files to MANIFEST.in
qgallouedec Apr 5, 2026
604c476
Enhance chat template handling to include tool call formatting in mes…
qgallouedec Apr 5, 2026
83a7ef6
align grpo and async
qgallouedec Apr 5, 2026
0f28384
Merge branch 'main' into chat-templates-files
qgallouedec Apr 6, 2026
e5d7cdf
revert no content
qgallouedec Apr 6, 2026
a618809
docstyle ignore
qgallouedec Apr 6, 2026
a0b81b1
Merge branch 'main' into chat-templates-files
qgallouedec Apr 6, 2026
67ab0af
Merge branch 'main' into chat-templates-files
qgallouedec Apr 6, 2026
63ec7d3
Merge branch 'main' into chat-templates-files
qgallouedec Apr 7, 2026
c838146
Merge branch 'main' into chat-templates-files
qgallouedec Apr 7, 2026
7b7f5d1
revert old modif
qgallouedec Apr 7, 2026
e37fe00
Merge branch 'main' into chat-templates-files
qgallouedec Apr 7, 2026
3c2fc8e
Add `{% generation %}` support to training chat templates
qgallouedec Apr 7, 2026
d887724
conditionally call this in grpo
qgallouedec Apr 7, 2026
9d4d57f
align
qgallouedec Apr 7, 2026
632e8b1
Merge branch 'main' into generation-tags
qgallouedec Apr 8, 2026
b33bc66
Refactor Qwen3 training template: Adjust generation block and output …
qgallouedec Apr 8, 2026
8820a29
Merge branch 'main' into generation-tags
qgallouedec Apr 9, 2026
31e640f
Merge branch 'main' into generation-tags
qgallouedec Apr 9, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/source/sft_trainer.md
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ training_args = SFTConfig(assistant_only_loss=True)
![train_on_assistant](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/train_on_assistant.png)

> [!WARNING]
> This functionality is only available for chat templates that support returning the assistant tokens mask via the `{% generation %}` and `{% endgeneration %}` keywords. For an example of such a template, see [HuggingFaceTB/SmolLM3-3B](https://huggingface.co/HuggingFaceTB/SmolLM3-3B/blob/main/chat_template.jinja#L76-L82).
> This functionality requires the chat template to include `{% generation %}` and `{% endgeneration %}` keywords. For known model families (e.g. Qwen3), TRL automatically patches the template when `assistant_only_loss=True`. For other models, check that your chat template includes these keywords — see [HuggingFaceTB/SmolLM3-3B](https://huggingface.co/HuggingFaceTB/SmolLM3-3B/blob/main/chat_template.jinja#L76-L82) for an example.

### Train on completion only

Expand Down
35 changes: 34 additions & 1 deletion tests/test_chat_template_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,6 @@ def test_non_prefix_preserving_template(self):
class TestGetTrainingChatTemplate:
def test_new_chat_template_is_prefix_preserving(self, tokenizer_name):
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
assert is_chat_template_prefix_preserving(tokenizer) is False
Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

for future models, get_training_chat_template will be called with any chat template, not only non-prefix-preserving ones

tokenizer.chat_template = get_training_chat_template(tokenizer)
assert is_chat_template_prefix_preserving(tokenizer) is True

Expand Down Expand Up @@ -392,6 +391,40 @@ def test_behavior_unchanged_generation_prompt_with_enable_thinking_false(self, t
)
assert before == after

def test_assistant_masks(self, tokenizer_name):
    """Single-turn check: the training template must mask exactly the assistant tokens.

    Applies the patched training chat template with ``return_assistant_tokens_mask=True``
    and verifies the mask covers the assistant reply (including its closing token) while
    leaving the leading user turn unmasked.
    """
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
    conversation = [
        {"role": "user", "content": "What color is the sky?"},
        {"role": "assistant", "content": "It is blue."},
    ]
    training_template = get_training_chat_template(tokenizer)
    output = tokenizer.apply_chat_template(
        conversation,
        chat_template=training_template,
        return_assistant_tokens_mask=True,
        return_dict=True,
    )
    assistant_mask = output["assistant_masks"]
    # At least one token must be attributed to the assistant turn.
    assert any(flag == 1 for flag in assistant_mask)
    # The conversation opens with the user turn, so the first token is unmasked.
    assert assistant_mask[0] == 0
    # The conversation ends inside the assistant turn (closing <|im_end|>), so the
    # final token is masked.
    assert assistant_mask[-1] == 1

def test_assistant_masks_multi_turn(self, tokenizer_name):
    """Multi-turn check: each assistant turn yields its own masked region.

    With two user/assistant exchanges, the assistant-token mask should flip value
    exactly three times (0→1, 1→0, 0→1), i.e. two distinct masked regions with the
    second one running to the end of the sequence.
    """
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
    conversation = [
        {"role": "user", "content": "Hi"},
        {"role": "assistant", "content": "Hello!"},
        {"role": "user", "content": "Bye"},
        {"role": "assistant", "content": "Goodbye!"},
    ]
    training_template = get_training_chat_template(tokenizer)
    output = tokenizer.apply_chat_template(
        conversation,
        chat_template=training_template,
        return_assistant_tokens_mask=True,
        return_dict=True,
    )
    assistant_mask = output["assistant_masks"]
    # Count value changes between consecutive mask entries; two assistant turns
    # produce the transition pattern 0→1, 1→0, 0→1.
    flips = 0
    for previous, current in zip(assistant_mask, assistant_mask[1:]):
        if previous != current:
            flips += 1
    assert flips == 3


@pytest.mark.parametrize(
"tokenizer_name",
Expand Down
12 changes: 7 additions & 5 deletions trl/chat_template_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,10 +273,11 @@ def is_chat_template_prefix_preserving(tokenizer: PreTrainedTokenizer) -> bool:

def get_training_chat_template(tokenizer: PreTrainedTokenizer) -> str | None:
r"""
Get a prefix-preserving chat template for training, if needed.
Get a training-compatible chat template, if needed.

If the tokenizer's template isn't prefix-preserving, returns a training-compatible template (currently Qwen3
supported). Otherwise, returns `None`.
Returns a patched chat template that is prefix-preserving and includes `{%% generation %%}` / `{%% endgeneration
%%}` markers for assistant-only loss masking. Returns `None` if the tokenizer's template already satisfies both
requirements. Currently Qwen3 is supported.

Args:
tokenizer (`PreTrainedTokenizer`):
Expand Down Expand Up @@ -322,14 +323,15 @@ def get_training_chat_template(tokenizer: PreTrainedTokenizer) -> str | None:
```
"""
# First check if patching is needed
if is_chat_template_prefix_preserving(tokenizer):
if is_chat_template_prefix_preserving(tokenizer) and "{% generation %}" in tokenizer.chat_template:
return None # No patching needed

if tokenizer.chat_template == qwen3_chat_template:
return qwen3_training_chat_template
else:
raise ValueError(
"The tokenizer's chat template is not prefix-preserving and patching is not supported for this template. "
"The tokenizer's chat template is not training-compatible (missing prefix-preservation or "
"`{% generation %}` markers) and patching is not supported for this template. "
"Please manually modify the tokenizer's chat template for training."
)

Expand Down
10 changes: 7 additions & 3 deletions trl/chat_templates/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@
Jinja2 chat templates stored here serve two purposes:

1. **Identity comparison**: detecting which model is being used (by comparing `tokenizer.chat_template` against known templates) to add the appropriate response schema (`add_response_schema`) or swap in a training template (`get_training_chat_template`).
2. **Training patches**: modified templates that fix prefix-preservation issues for the GRPO tool call loop.
2. **Training patches**: modified templates that fix training-specific issues (prefix-preservation for GRPO, `{% generation %}` markers for SFT assistant-only loss).

**Why prefix-preserving?** The GRPO tool call loop extracts tool response formatting tokens by comparing tokenizations with and without tool messages appended (`_get_tool_suffix_ids`). This requires the chat template to be *prefix-preserving*: appending messages must not change how earlier messages are rendered.

**Why generation-tagged?** SFT with `assistant_only_loss=True` requires the chat template to include `{% generation %}` / `{% endgeneration %}` markers around assistant output, so `return_assistant_tokens_mask=True` can produce correct masks. Most model templates don't include these markers natively.

## Original templates

Used for identity comparison only.
Expand All @@ -21,11 +23,11 @@ Original Qwen3.5 chat templates.

## Training templates

Patched templates that fix prefix-preservation issues. Swapped in at init when tools are enabled.
Patched templates that fix training-specific issues. Swapped in at init when tools are enabled (GRPO) or when `assistant_only_loss=True` (SFT).

### `qwen3_training.jinja`

Patched Qwen3 template that always includes thinking blocks, making it prefix-preserving. Diff vs `qwen3.jinja`:
Patched Qwen3 template. Diff vs `qwen3.jinja`:

Require both `<think>` and `</think>` to be present before parsing, to avoid incorrect splitting when the model generates only one tag:

Expand All @@ -48,3 +50,5 @@ Always include the thinking block regardless of message position. The original c
- {%- endif %}
+ {{- '<|im_start|>' + message.role + '\n<think>\n' + reasoning_content.strip('\n') + '\n</think>\n\n' + content.lstrip('\n') }}
```

Wrap assistant message output with `{% generation %}` / `{% endgeneration %}` so that `return_assistant_tokens_mask=True` produces correct masks for SFT assistant-only loss.
4 changes: 4 additions & 0 deletions trl/chat_templates/qwen3_training.jinja
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
Always check for both tags to avoid edge cases where the model generates only one tag.
- Removed the loop.index0 > ns.last_query_index conditional; always include thinking block.
This makes the template prefix-preserving for the [user, assistant] → [user, assistant, tool] transition.
- Added {% generation %} / {% endgeneration %} around assistant message output to support
assistant-only loss masking in SFT training.
-#}
{%- if tools %}
{{- '<|im_start|>system\n' }}
Expand Down Expand Up @@ -47,6 +49,7 @@
{%- set content = content.split('</think>')[-1].lstrip('\n') %}
{%- endif %}
{%- endif %}
{% generation %}
{{- '<|im_start|>' + message.role + '\n<think>\n' + reasoning_content.strip('\n') + '\n</think>\n\n' + content.lstrip('\n') }}
{%- if message.tool_calls %}
{%- for tool_call in message.tool_calls %}
Expand All @@ -68,6 +71,7 @@
{%- endfor %}
{%- endif %}
{{- '<|im_end|>\n' }}
{%- endgeneration %}
{%- elif message.role == "tool" %}
{%- if loop.first or (messages[loop.index0 - 1].role != "tool") %}
{{- '<|im_start|>user' }}
Expand Down
14 changes: 12 additions & 2 deletions trl/experimental/async_grpo/async_rollout_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,12 @@
from datasets import Dataset
from transformers import AutoTokenizer

from trl.chat_template_utils import add_response_schema, get_training_chat_template, parse_response
from trl.chat_template_utils import (
add_response_schema,
get_training_chat_template,
is_chat_template_prefix_preserving,
parse_response,
)
from trl.import_utils import is_vllm_available
from trl.trainer.utils import print_prompt_completions_sample

Expand Down Expand Up @@ -162,7 +167,12 @@ def __init__(
self.num_completions_to_print = num_completions_to_print
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
self.tokenizer = add_response_schema(self.tokenizer)
self.chat_template = get_training_chat_template(self.tokenizer)
# In multi-turn training, the chat template *must* be prefix-preserving. If the tokenizer's original template
# isn't, we replace it at initialization with a training-safe, prefix-preserving template.
if self.tools and not is_chat_template_prefix_preserving(self.tokenizer):
self.chat_template = get_training_chat_template(self.tokenizer)
else:
self.chat_template = None

self._groups_to_score: asyncio.Queue[RolloutGroup | None] = asyncio.Queue(maxsize=16)
self._total_completion_tokens = 0
Expand Down
9 changes: 7 additions & 2 deletions trl/trainer/grpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,12 @@
)
from transformers.utils import is_peft_available, is_rich_available

from ..chat_template_utils import add_response_schema, get_training_chat_template, parse_response
from ..chat_template_utils import (
add_response_schema,
get_training_chat_template,
is_chat_template_prefix_preserving,
parse_response,
)
from ..data_utils import apply_chat_template, is_conversational, prepare_multimodal_messages
from ..extras.profiling import profiling_context, profiling_decorator
from ..generation.vllm_generation import VLLMGeneration
Expand Down Expand Up @@ -515,7 +520,7 @@ def __init__(
processing_class = add_response_schema(processing_class)
# In multi-turn training, the chat template *must* be prefix-preserving. If the tokenizer's original template
# isn't, we replace it at initialization with a training-safe, prefix-preserving template.
if self.tools:
if self.tools and not is_chat_template_prefix_preserving(processing_class):
self.chat_template = get_training_chat_template(processing_class)
else:
self.chat_template = None
Expand Down
13 changes: 11 additions & 2 deletions trl/trainer/sft_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
from transformers.trainer_utils import EvalPrediction
from transformers.utils import is_peft_available

from ..chat_template_utils import clone_chat_template
from ..chat_template_utils import clone_chat_template, get_training_chat_template
from ..data_utils import (
apply_chat_template,
is_conversational,
Expand Down Expand Up @@ -922,6 +922,13 @@ def __init__(
"supported for conversational datasets."
)

# When assistant_only_loss is enabled, swap in a training chat template with {% generation %} markers
# if the current template doesn't already have them.
if args.assistant_only_loss and "{% generation %}" not in processing_class.chat_template:
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Potential crash when chat_template is None

Low Severity

The membership check `"{% generation %}" not in processing_class.chat_template` will raise a `TypeError` if `processing_class.chat_template` is `None`. This can happen with tokenizers that have no chat template set. A guard like `processing_class.chat_template and "{% generation %}" not in processing_class.chat_template` would prevent the crash and provide a clearer path to the downstream `ValueError` from `get_training_chat_template`.

Fix in Cursor Fix in Web

Reviewed by Cursor Bugbot for commit 9d4d57f. Configure here.

self.chat_template = get_training_chat_template(processing_class)
else:
self.chat_template = None

# Dataset
if self.padding_free and not args.packing and args.max_length is not None and not self._is_vision_dataset:
raise ValueError(
Expand Down Expand Up @@ -1038,7 +1045,9 @@ def _tokenize(
if isinstance(input, list): # conversational: list of message dicts
if self._is_vlm:
input = prepare_multimodal_messages(input)
result = processing_class.apply_chat_template(input, tokenize=True, return_dict=True, **kwargs)
result = processing_class.apply_chat_template(
input, tokenize=True, return_dict=True, chat_template=self.chat_template, **kwargs
)
else: # non-conversational: plain text string
result = processing_class(text=input)
# VLMs emit a batch dimension even for single examples; unwrap it
Expand Down
Loading