diff --git a/docs/source/sft_trainer.md b/docs/source/sft_trainer.md index b99dd9f2e5..089bfc595f 100644 --- a/docs/source/sft_trainer.md +++ b/docs/source/sft_trainer.md @@ -169,7 +169,7 @@ training_args = SFTConfig(assistant_only_loss=True) ![train_on_assistant](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/train_on_assistant.png) > [!WARNING] -> This functionality is only available for chat templates that support returning the assistant tokens mask via the `{% generation %}` and `{% endgeneration %}` keywords. For an example of such a template, see [HugggingFaceTB/SmolLM3-3B](https://huggingface.co/HuggingFaceTB/SmolLM3-3B/blob/main/chat_template.jinja#L76-L82). +> This functionality requires the chat template to include `{% generation %}` and `{% endgeneration %}` keywords. For known model families (e.g. Qwen3), TRL automatically patches the template when `assistant_only_loss=True`. For other models, check that your chat template includes these keywords — see [HuggingFaceTB/SmolLM3-3B](https://huggingface.co/HuggingFaceTB/SmolLM3-3B/blob/main/chat_template.jinja#L76-L82) for an example. 
### Train on completion only diff --git a/tests/test_chat_template_utils.py b/tests/test_chat_template_utils.py index 4c4193d8b9..3242a673b4 100644 --- a/tests/test_chat_template_utils.py +++ b/tests/test_chat_template_utils.py @@ -315,7 +315,6 @@ def test_non_prefix_preserving_template(self): class TestGetTrainingChatTemplate: def test_new_chat_template_is_prefix_preserving(self, tokenizer_name): tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) - assert is_chat_template_prefix_preserving(tokenizer) is False tokenizer.chat_template = get_training_chat_template(tokenizer) assert is_chat_template_prefix_preserving(tokenizer) is True @@ -461,6 +460,40 @@ def test_behavior_unchanged_generation_prompt_with_enable_thinking_false(self, t ) assert before == after + def test_assistant_masks(self, tokenizer_name): + tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) + messages = [ + {"role": "user", "content": "What color is the sky?"}, + {"role": "assistant", "content": "It is blue."}, + ] + chat_template = get_training_chat_template(tokenizer) + result = tokenizer.apply_chat_template( + messages, chat_template=chat_template, return_assistant_tokens_mask=True, return_dict=True + ) + masks = result["assistant_masks"] + assert 1 in masks + # The first tokens (user turn) should not be masked + assert masks[0] == 0 + # The last tokens (assistant turn ending with <|im_end|>) should be masked + assert masks[-1] == 1 + + def test_assistant_masks_multi_turn(self, tokenizer_name): + tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) + messages = [ + {"role": "user", "content": "Hi"}, + {"role": "assistant", "content": "Hello!"}, + {"role": "user", "content": "Bye"}, + {"role": "assistant", "content": "Goodbye!"}, + ] + chat_template = get_training_chat_template(tokenizer) + result = tokenizer.apply_chat_template( + messages, chat_template=chat_template, return_assistant_tokens_mask=True, return_dict=True + ) + masks = result["assistant_masks"] + # Should have 
two masked regions (two assistant turns): 0→1, 1→0, 0→1 + transitions = sum(1 for i in range(1, len(masks)) if masks[i] != masks[i - 1]) + assert transitions == 3 + @pytest.mark.parametrize( "tokenizer_name", diff --git a/trl/chat_template_utils.py b/trl/chat_template_utils.py index 441de1f176..67f059ce02 100644 --- a/trl/chat_template_utils.py +++ b/trl/chat_template_utils.py @@ -370,10 +370,11 @@ def is_chat_template_prefix_preserving(tokenizer: PreTrainedTokenizer) -> bool: def get_training_chat_template(tokenizer: PreTrainedTokenizer) -> str | None: r""" - Get a prefix-preserving chat template for training, if needed. + Get a training-compatible chat template, if needed. - If the tokenizer's template isn't prefix-preserving, returns a training-compatible template (currently Qwen3 - supported). Otherwise, returns `None`. + Returns a patched chat template that is prefix-preserving and includes `{% generation %}` / `{% endgeneration + %}` markers for assistant-only loss masking. Returns `None` if the tokenizer's template already satisfies both + requirements. Currently Qwen3 is supported. Args: tokenizer (`PreTrainedTokenizer`): @@ -419,14 +420,15 @@ def get_training_chat_template(tokenizer: PreTrainedTokenizer) -> str | None: ``` """ # First check if patching is needed - if is_chat_template_prefix_preserving(tokenizer): + # Check both marker spellings: templates may use whitespace-control tags ({%- generation %}), + # as our own qwen3_training.jinja does; also guard against chat_template being None. + has_generation_markers = any( + marker in (tokenizer.chat_template or "") for marker in ("{% generation %}", "{%- generation %}") + ) + if is_chat_template_prefix_preserving(tokenizer) and has_generation_markers: return None # No patching needed if tokenizer.chat_template == qwen3_chat_template: return qwen3_training_chat_template else: raise ValueError( - "The tokenizer's chat template is not prefix-preserving and patching is not supported for this template. " + "The tokenizer's chat template is not training-compatible (missing prefix-preservation or " + "`{% generation %}` markers) and patching is not supported for this template. " "Please manually modify the tokenizer's chat template for training."
) diff --git a/trl/chat_templates/README.md b/trl/chat_templates/README.md index 4a7ac33e8c..4a2cba97ae 100644 --- a/trl/chat_templates/README.md +++ b/trl/chat_templates/README.md @@ -3,10 +3,12 @@ Jinja2 chat templates stored here serve two purposes: 1. **Identity comparison**: detecting which model is being used (by comparing `tokenizer.chat_template` against known templates) to add the appropriate response schema (`add_response_schema`) or swap in a training template (`get_training_chat_template`). -2. **Training patches**: modified templates that fix prefix-preservation issues for the GRPO tool call loop. +2. **Training patches**: modified templates that fix training-specific issues (prefix-preservation for GRPO, `{% generation %}` markers for SFT assistant-only loss). **Why prefix-preserving?** The GRPO tool call loop extracts tool response formatting tokens by comparing tokenizations with and without tool messages appended (`_get_tool_suffix_ids`). This requires the chat template to be *prefix-preserving*: appending messages must not change how earlier messages are rendered. +**Why generation-tagged?** SFT with `assistant_only_loss=True` requires the chat template to include `{% generation %}` / `{% endgeneration %}` markers around assistant output, so `return_assistant_tokens_mask=True` can produce correct masks. Most model templates don't include these markers natively. + ## Original templates Used for identity comparison only. @@ -25,11 +27,11 @@ Original Qwen3.5 chat templates. ## Training templates -Patched templates that fix prefix-preservation issues. Swapped in at init when tools are enabled. +Patched templates that fix training-specific issues. Swapped in at init when tools are enabled (GRPO) or when `assistant_only_loss=True` (SFT). ### `qwen3_training.jinja` -Patched Qwen3 template that always includes thinking blocks, making it prefix-preserving. Diff vs `qwen3.jinja`: +Patched Qwen3 template. 
Diff vs `qwen3.jinja`: Require both `` and `` to be present before parsing, to avoid incorrect splitting when the model generates only one tag: @@ -52,3 +54,5 @@ Always include the thinking block regardless of message position. The original c - {%- endif %} + {{- '<|im_start|>' + message.role + '\n\n' + reasoning_content.strip('\n') + '\n\n\n' + content.lstrip('\n') }} ``` + +Wrap assistant message output with `{% generation %}` / `{% endgeneration %}` so that `return_assistant_tokens_mask=True` produces correct masks for SFT assistant-only loss. diff --git a/trl/chat_templates/qwen3_training.jinja b/trl/chat_templates/qwen3_training.jinja index 978e58e3eb..decc48dc6b 100644 --- a/trl/chat_templates/qwen3_training.jinja +++ b/trl/chat_templates/qwen3_training.jinja @@ -4,6 +4,8 @@ Always check for both tags to avoid edge cases where the model generates only one tag. - Removed the loop.index0 > ns.last_query_index conditional; always include thinking block. This makes the template prefix-preserving for the [user, assistant] → [user, assistant, tool] transition. + - Added {% generation %} / {% endgeneration %} around assistant message output to support + assistant-only loss masking in SFT training. 
-#} {%- if tools %} {{- '<|im_start|>system\n' }} @@ -47,7 +49,9 @@ {%- set content = content.split('')[-1].lstrip('\n') %} {%- endif %} {%- endif %} - {{- '<|im_start|>' + message.role + '\n\n' + reasoning_content.strip('\n') + '\n\n\n' + content.lstrip('\n') }} + {{- '<|im_start|>' + message.role + '\n' }} + {%- generation %} + {{- '\n' + reasoning_content.strip('\n') + '\n\n\n' + content.lstrip('\n') }} {%- if message.tool_calls %} {%- for tool_call in message.tool_calls %} {%- if (loop.first and content) or (not loop.first) %} @@ -68,6 +72,7 @@ {%- endfor %} {%- endif %} {{- '<|im_end|>\n' }} + {%- endgeneration %} {%- elif message.role == "tool" %} {%- if loop.first or (messages[loop.index0 - 1].role != "tool") %} {{- '<|im_start|>user' }} diff --git a/trl/experimental/async_grpo/async_rollout_worker.py b/trl/experimental/async_grpo/async_rollout_worker.py index 5837939519..4fd11312fd 100644 --- a/trl/experimental/async_grpo/async_rollout_worker.py +++ b/trl/experimental/async_grpo/async_rollout_worker.py @@ -28,7 +28,12 @@ from datasets import Dataset from transformers import AutoTokenizer -from trl.chat_template_utils import add_response_schema, get_training_chat_template, parse_response +from trl.chat_template_utils import ( + add_response_schema, + get_training_chat_template, + is_chat_template_prefix_preserving, + parse_response, +) from trl.import_utils import is_vllm_available from trl.trainer.utils import print_prompt_completions_sample @@ -162,7 +167,12 @@ def __init__( self.num_completions_to_print = num_completions_to_print self.tokenizer = AutoTokenizer.from_pretrained(model_name) self.tokenizer = add_response_schema(self.tokenizer) - self.chat_template = get_training_chat_template(self.tokenizer) + # In multi-turn training, the chat template *must* be prefix-preserving. If the tokenizer's original template + # isn't, we replace it at initialization with a training-safe, prefix-preserving template. 
+ if self.tools and not is_chat_template_prefix_preserving(self.tokenizer): + self.chat_template = get_training_chat_template(self.tokenizer) + else: + self.chat_template = None self._groups_to_score: asyncio.Queue[RolloutGroup | None] = asyncio.Queue(maxsize=16) self._total_completion_tokens = 0 diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index 1b1d6a1e3b..93ccde655c 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -59,6 +59,7 @@ from ..chat_template_utils import ( add_response_schema, get_training_chat_template, + is_chat_template_prefix_preserving, parse_response, supports_tool_calling, ) @@ -525,7 +526,7 @@ def __init__( processing_class = add_response_schema(processing_class) # In multi-turn training, the chat template *must* be prefix-preserving. If the tokenizer's original template # isn't, we replace it at initialization with a training-safe, prefix-preserving template. - if self.tools: + if self.tools and not is_chat_template_prefix_preserving(processing_class): self.chat_template = get_training_chat_template(processing_class) else: self.chat_template = None diff --git a/trl/trainer/sft_trainer.py b/trl/trainer/sft_trainer.py index 1807520dc7..fe0b3fbe10 100644 --- a/trl/trainer/sft_trainer.py +++ b/trl/trainer/sft_trainer.py @@ -43,7 +43,7 @@ from transformers.trainer_utils import EvalPrediction from transformers.utils import is_peft_available -from ..chat_template_utils import clone_chat_template +from ..chat_template_utils import clone_chat_template, get_training_chat_template from ..data_utils import ( apply_chat_template, is_conversational, @@ -922,6 +922,13 @@ def __init__( "supported for conversational datasets." ) + # When assistant_only_loss is enabled, swap in a training chat template with {% generation %} markers + # if the current template doesn't already have them. 
+ # Check both marker spellings ({% generation %} and whitespace-control {%- generation %}, as used + # by qwen3_training.jinja) and guard against chat_template being None. + has_generation_markers = any( + marker in (processing_class.chat_template or "") + for marker in ("{% generation %}", "{%- generation %}") + ) + if args.assistant_only_loss and not has_generation_markers: + self.chat_template = get_training_chat_template(processing_class) + else: + self.chat_template = None + # Dataset if self.padding_free and not args.packing and args.max_length is not None and not self._is_vision_dataset: raise ValueError( @@ -1038,7 +1045,9 @@ def _tokenize( if isinstance(input, list): # conversational: list of message dicts if self._is_vlm: input = prepare_multimodal_messages(input) - result = processing_class.apply_chat_template(input, tokenize=True, return_dict=True, **kwargs) + result = processing_class.apply_chat_template( + input, tokenize=True, return_dict=True, chat_template=self.chat_template, **kwargs + ) else: # non-conversational: plain text string result = processing_class(text=input) # VLMs emit a batch dimension even for single examples; unwrap it