diff --git a/docs/source/sft_trainer.md b/docs/source/sft_trainer.md
index b99dd9f2e5..089bfc595f 100644
--- a/docs/source/sft_trainer.md
+++ b/docs/source/sft_trainer.md
@@ -169,7 +169,7 @@ training_args = SFTConfig(assistant_only_loss=True)

> [!WARNING]
-> This functionality is only available for chat templates that support returning the assistant tokens mask via the `{% generation %}` and `{% endgeneration %}` keywords. For an example of such a template, see [HugggingFaceTB/SmolLM3-3B](https://huggingface.co/HuggingFaceTB/SmolLM3-3B/blob/main/chat_template.jinja#L76-L82).
+> This functionality requires the chat template to include `{% generation %}` and `{% endgeneration %}` keywords. For known model families (e.g. Qwen3), TRL automatically patches the template when `assistant_only_loss=True`. For other models, check that your chat template includes these keywords — see [HuggingFaceTB/SmolLM3-3B](https://huggingface.co/HuggingFaceTB/SmolLM3-3B/blob/main/chat_template.jinja#L76-L82) for an example.
### Train on completion only
diff --git a/tests/test_chat_template_utils.py b/tests/test_chat_template_utils.py
index 4c4193d8b9..3242a673b4 100644
--- a/tests/test_chat_template_utils.py
+++ b/tests/test_chat_template_utils.py
@@ -315,7 +315,6 @@ def test_non_prefix_preserving_template(self):
class TestGetTrainingChatTemplate:
def test_new_chat_template_is_prefix_preserving(self, tokenizer_name):
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
- assert is_chat_template_prefix_preserving(tokenizer) is False
tokenizer.chat_template = get_training_chat_template(tokenizer)
assert is_chat_template_prefix_preserving(tokenizer) is True
@@ -461,6 +460,40 @@ def test_behavior_unchanged_generation_prompt_with_enable_thinking_false(self, t
)
assert before == after
+ def test_assistant_masks(self, tokenizer_name):
+ tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
+ messages = [
+ {"role": "user", "content": "What color is the sky?"},
+ {"role": "assistant", "content": "It is blue."},
+ ]
+ chat_template = get_training_chat_template(tokenizer)
+ result = tokenizer.apply_chat_template(
+ messages, chat_template=chat_template, return_assistant_tokens_mask=True, return_dict=True
+ )
+ masks = result["assistant_masks"]
+ assert 1 in masks
+ # The first tokens (user turn) should not be masked
+ assert masks[0] == 0
+ # The last tokens (assistant turn ending with <|im_end|>) should be masked
+ assert masks[-1] == 1
+
+ def test_assistant_masks_multi_turn(self, tokenizer_name):
+ tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
+ messages = [
+ {"role": "user", "content": "Hi"},
+ {"role": "assistant", "content": "Hello!"},
+ {"role": "user", "content": "Bye"},
+ {"role": "assistant", "content": "Goodbye!"},
+ ]
+ chat_template = get_training_chat_template(tokenizer)
+ result = tokenizer.apply_chat_template(
+ messages, chat_template=chat_template, return_assistant_tokens_mask=True, return_dict=True
+ )
+ masks = result["assistant_masks"]
+ # Should have two masked regions (two assistant turns): 0→1, 1→0, 0→1
+ transitions = sum(1 for i in range(1, len(masks)) if masks[i] != masks[i - 1])
+ assert transitions == 3
+
@pytest.mark.parametrize(
"tokenizer_name",
diff --git a/trl/chat_template_utils.py b/trl/chat_template_utils.py
index 441de1f176..67f059ce02 100644
--- a/trl/chat_template_utils.py
+++ b/trl/chat_template_utils.py
@@ -370,10 +370,11 @@ def is_chat_template_prefix_preserving(tokenizer: PreTrainedTokenizer) -> bool:
def get_training_chat_template(tokenizer: PreTrainedTokenizer) -> str | None:
r"""
- Get a prefix-preserving chat template for training, if needed.
+ Get a training-compatible chat template, if needed.
- If the tokenizer's template isn't prefix-preserving, returns a training-compatible template (currently Qwen3
- supported). Otherwise, returns `None`.
+    Returns a patched chat template that is prefix-preserving and includes `{% generation %}` /
+    `{% endgeneration %}` markers for assistant-only loss masking. Returns `None` if the tokenizer's template
+    already satisfies both requirements. Currently Qwen3 is supported.
Args:
tokenizer (`PreTrainedTokenizer`):
@@ -419,14 +420,15 @@ def get_training_chat_template(tokenizer: PreTrainedTokenizer) -> str | None:
```
"""
# First check if patching is needed
- if is_chat_template_prefix_preserving(tokenizer):
+    if is_chat_template_prefix_preserving(tokenizer) and (
+        "{% generation %}" in tokenizer.chat_template or "{%- generation %}" in tokenizer.chat_template
+    ):
return None # No patching needed
if tokenizer.chat_template == qwen3_chat_template:
return qwen3_training_chat_template
else:
raise ValueError(
- "The tokenizer's chat template is not prefix-preserving and patching is not supported for this template. "
+ "The tokenizer's chat template is not training-compatible (missing prefix-preservation or "
+ "`{% generation %}` markers) and patching is not supported for this template. "
"Please manually modify the tokenizer's chat template for training."
)
diff --git a/trl/chat_templates/README.md b/trl/chat_templates/README.md
index 4a7ac33e8c..4a2cba97ae 100644
--- a/trl/chat_templates/README.md
+++ b/trl/chat_templates/README.md
@@ -3,10 +3,12 @@
Jinja2 chat templates stored here serve two purposes:
1. **Identity comparison**: detecting which model is being used (by comparing `tokenizer.chat_template` against known templates) to add the appropriate response schema (`add_response_schema`) or swap in a training template (`get_training_chat_template`).
-2. **Training patches**: modified templates that fix prefix-preservation issues for the GRPO tool call loop.
+2. **Training patches**: modified templates that fix training-specific issues (prefix-preservation for GRPO, `{% generation %}` markers for SFT assistant-only loss).
**Why prefix-preserving?** The GRPO tool call loop extracts tool response formatting tokens by comparing tokenizations with and without tool messages appended (`_get_tool_suffix_ids`). This requires the chat template to be *prefix-preserving*: appending messages must not change how earlier messages are rendered.
+**Why generation-tagged?** SFT with `assistant_only_loss=True` requires the chat template to include `{% generation %}` / `{% endgeneration %}` markers around assistant output, so `return_assistant_tokens_mask=True` can produce correct masks. Most model templates don't include these markers natively.
+
## Original templates
Used for identity comparison only.
@@ -25,11 +27,11 @@ Original Qwen3.5 chat templates.
## Training templates
-Patched templates that fix prefix-preservation issues. Swapped in at init when tools are enabled.
+Patched templates that fix training-specific issues. Swapped in at init when tools are enabled (GRPO) or when `assistant_only_loss=True` (SFT).
### `qwen3_training.jinja`
-Patched Qwen3 template that always includes thinking blocks, making it prefix-preserving. Diff vs `qwen3.jinja`:
+Patched Qwen3 template. Diff vs `qwen3.jinja`:
Require both `<think>` and `</think>` to be present before parsing, to avoid incorrect splitting when the model generates only one tag:
@@ -52,3 +54,5 @@ Always include the thinking block regardless of message position. The original c
- {%- endif %}
+ {{- '<|im_start|>' + message.role + '\n<think>\n' + reasoning_content.strip('\n') + '\n</think>\n\n' + content.lstrip('\n') }}
```
+
+Wrap assistant message output with `{% generation %}` / `{% endgeneration %}` so that `return_assistant_tokens_mask=True` produces correct masks for SFT assistant-only loss.
diff --git a/trl/chat_templates/qwen3_training.jinja b/trl/chat_templates/qwen3_training.jinja
index 978e58e3eb..decc48dc6b 100644
--- a/trl/chat_templates/qwen3_training.jinja
+++ b/trl/chat_templates/qwen3_training.jinja
@@ -4,6 +4,8 @@
Always check for both tags to avoid edge cases where the model generates only one tag.
- Removed the loop.index0 > ns.last_query_index conditional; always include thinking block.
This makes the template prefix-preserving for the [user, assistant] → [user, assistant, tool] transition.
+ - Added {% generation %} / {% endgeneration %} around assistant message output to support
+ assistant-only loss masking in SFT training.
-#}
{%- if tools %}
{{- '<|im_start|>system\n' }}
@@ -47,7 +49,9 @@
{%- set content = content.split('</think>')[-1].lstrip('\n') %}
{%- endif %}
{%- endif %}
- {{- '<|im_start|>' + message.role + '\n<think>\n' + reasoning_content.strip('\n') + '\n</think>\n\n' + content.lstrip('\n') }}
+ {{- '<|im_start|>' + message.role + '\n' }}
+ {%- generation %}
+ {{- '<think>\n' + reasoning_content.strip('\n') + '\n</think>\n\n' + content.lstrip('\n') }}
{%- if message.tool_calls %}
{%- for tool_call in message.tool_calls %}
{%- if (loop.first and content) or (not loop.first) %}
@@ -68,6 +72,7 @@
{%- endfor %}
{%- endif %}
{{- '<|im_end|>\n' }}
+ {%- endgeneration %}
{%- elif message.role == "tool" %}
{%- if loop.first or (messages[loop.index0 - 1].role != "tool") %}
{{- '<|im_start|>user' }}
diff --git a/trl/experimental/async_grpo/async_rollout_worker.py b/trl/experimental/async_grpo/async_rollout_worker.py
index 5837939519..4fd11312fd 100644
--- a/trl/experimental/async_grpo/async_rollout_worker.py
+++ b/trl/experimental/async_grpo/async_rollout_worker.py
@@ -28,7 +28,12 @@
from datasets import Dataset
from transformers import AutoTokenizer
-from trl.chat_template_utils import add_response_schema, get_training_chat_template, parse_response
+from trl.chat_template_utils import (
+ add_response_schema,
+ get_training_chat_template,
+ is_chat_template_prefix_preserving,
+ parse_response,
+)
from trl.import_utils import is_vllm_available
from trl.trainer.utils import print_prompt_completions_sample
@@ -162,7 +167,12 @@ def __init__(
self.num_completions_to_print = num_completions_to_print
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
self.tokenizer = add_response_schema(self.tokenizer)
- self.chat_template = get_training_chat_template(self.tokenizer)
+ # In multi-turn training, the chat template *must* be prefix-preserving. If the tokenizer's original template
+ # isn't, we replace it at initialization with a training-safe, prefix-preserving template.
+ if self.tools and not is_chat_template_prefix_preserving(self.tokenizer):
+ self.chat_template = get_training_chat_template(self.tokenizer)
+ else:
+ self.chat_template = None
self._groups_to_score: asyncio.Queue[RolloutGroup | None] = asyncio.Queue(maxsize=16)
self._total_completion_tokens = 0
diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py
index 1b1d6a1e3b..93ccde655c 100644
--- a/trl/trainer/grpo_trainer.py
+++ b/trl/trainer/grpo_trainer.py
@@ -59,6 +59,7 @@
from ..chat_template_utils import (
add_response_schema,
get_training_chat_template,
+ is_chat_template_prefix_preserving,
parse_response,
supports_tool_calling,
)
@@ -525,7 +526,7 @@ def __init__(
processing_class = add_response_schema(processing_class)
# In multi-turn training, the chat template *must* be prefix-preserving. If the tokenizer's original template
# isn't, we replace it at initialization with a training-safe, prefix-preserving template.
- if self.tools:
+ if self.tools and not is_chat_template_prefix_preserving(processing_class):
self.chat_template = get_training_chat_template(processing_class)
else:
self.chat_template = None
diff --git a/trl/trainer/sft_trainer.py b/trl/trainer/sft_trainer.py
index 1807520dc7..fe0b3fbe10 100644
--- a/trl/trainer/sft_trainer.py
+++ b/trl/trainer/sft_trainer.py
@@ -43,7 +43,7 @@
from transformers.trainer_utils import EvalPrediction
from transformers.utils import is_peft_available
-from ..chat_template_utils import clone_chat_template
+from ..chat_template_utils import clone_chat_template, get_training_chat_template
from ..data_utils import (
apply_chat_template,
is_conversational,
@@ -922,6 +922,13 @@ def __init__(
"supported for conversational datasets."
)
+ # When assistant_only_loss is enabled, swap in a training chat template with {% generation %} markers
+ # if the current template doesn't already have them.
+        if args.assistant_only_loss and not (
+            "{% generation %}" in processing_class.chat_template
+            or "{%- generation %}" in processing_class.chat_template
+        ):
+ self.chat_template = get_training_chat_template(processing_class)
+ else:
+ self.chat_template = None
+
# Dataset
if self.padding_free and not args.packing and args.max_length is not None and not self._is_vision_dataset:
raise ValueError(
@@ -1038,7 +1045,9 @@ def _tokenize(
if isinstance(input, list): # conversational: list of message dicts
if self._is_vlm:
input = prepare_multimodal_messages(input)
- result = processing_class.apply_chat_template(input, tokenize=True, return_dict=True, **kwargs)
+ result = processing_class.apply_chat_template(
+ input, tokenize=True, return_dict=True, chat_template=self.chat_template, **kwargs
+ )
else: # non-conversational: plain text string
result = processing_class(text=input)
# VLMs emit a batch dimension even for single examples; unwrap it