diff --git a/tests/test_chat_template_utils.py b/tests/test_chat_template_utils.py index 9fa45e5f71..2cdd45b85f 100644 --- a/tests/test_chat_template_utils.py +++ b/tests/test_chat_template_utils.py @@ -146,13 +146,25 @@ def test_add_response_schema(self, tokenizer_name): class TestIsChatTemplatePrefixPreserving: def test_prefix_preserving_template(self): tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen3MoeForSequenceClassification") + # docstyle-ignore tokenizer.chat_template = textwrap.dedent(r""" {%- for message in messages %} {%- if message.role == 'user' %} {{- '<|im_start|>user\n' + message.content + '<|im_end|>\n' }} {%- elif message.role == 'assistant' %} - {{- '<|im_start|>assistant\n' + message.content + '<|im_end|>\n' }} + {{- '<|im_start|>assistant\n' + message.content }} + {%- if message.tool_calls %} + {%- for tool_call in message.tool_calls %} + {%- if tool_call.function %} + {%- set tool_call = tool_call.function %} + {%- endif %} + {{- '' + tool_call.name + '' }} + {%- endfor %} + {%- endif %} + {{- '<|im_end|>\n' }} + {%- elif message.role == 'tool' %} + {{- '<|im_start|>tool\n' + message.content + '<|im_end|>\n' }} {%- endif %} {%- endfor %} @@ -164,8 +176,9 @@ def test_prefix_preserving_template(self): def test_non_prefix_preserving_template(self): tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen3MoeForSequenceClassification") - # The following template is quite typical of models like Qwen3 and GPT-OSS, where the thinking part is - # only present for last assistant message, which makes it non-prefix-preserving. + # The following template is quite typical of models like Qwen3 and GPT-OSS, where the thinking part (even + # empty) is only present for last assistant message, which makes it non-prefix-preserving: appending a tool + # message changes the earlier output. # docstyle-ignore tokenizer.chat_template = textwrap.dedent(r""" {%- if messages[0].role == 'system' %} @@ -202,7 +215,17 @@ def test_non_prefix_preserving_template(self): {%- else %} {{- '<|im_start|>' + message.role + '\n' + content }} {%- endif %} + {%- if message.tool_calls %} + {%- for tool_call in message.tool_calls %} + {%- if tool_call.function %} + {%- set tool_call = tool_call.function %} + {%- endif %} + {{- '' + tool_call.name + '' }} + {%- endfor %} + {%- endif %} {{- '<|im_end|>\n' }} + {%- elif message.role == "tool" %} + {{- '<|im_start|>tool\n' + content + '<|im_end|>\n' }} {%- endif %} {%- endfor %} {%- if add_generation_prompt %} @@ -218,14 +241,6 @@ def test_non_prefix_preserving_template(self): "tokenizer_name", [ pytest.param("trl-internal-testing/tiny-Qwen3MoeForSequenceClassification", id="qwen3"), - pytest.param( - "trl-internal-testing/tiny-Qwen3_5ForConditionalGeneration", - id="qwen35", - marks=pytest.mark.skipif( - Version(transformers.__version__) < Version("5.0.0"), - reason="Qwen3.5 tokenizer requires transformers>=5.0.0", - ), - ), ], ) class TestGetTrainingChatTemplate: diff --git a/trl/chat_template_utils.py b/trl/chat_template_utils.py index 86317a5467..4ea93b92ad 100644 --- a/trl/chat_template_utils.py +++ b/trl/chat_template_utils.py @@ -630,6 +630,10 @@ def is_chat_template_prefix_preserving(tokenizer: PreTrainedTokenizer) -> bool: """ Check whether the chat template preserves prefixes when applied. + A prefix-preserving chat template renders earlier messages identically regardless of what messages follow. This + property is required by `_get_tool_suffix_ids`, which extracts tool response formatting tokens by comparing + tokenizations with and without tool messages appended. + Args: tokenizer (`PreTrainedTokenizer`): Tokenizer instance to check. @@ -638,24 +642,22 @@ def is_chat_template_prefix_preserving(tokenizer: PreTrainedTokenizer) -> bool: `bool`: `True` if the chat template preserves prefixes, `False` otherwise. """ + # Use the same dummy messages as _get_tool_suffix_ids to test the exact property it relies on. + dummy_tool_calls = [{"type": "function", "function": {"name": "dummy", "arguments": {}}}] messages1 = [ - {"role": "user", "content": "What color is the sky?"}, + {"role": "user", "content": "dummy"}, + {"role": "assistant", "content": "", "tool_calls": dummy_tool_calls}, ] messages2 = [ - {"role": "user", "content": "What color is the sky?"}, - {"role": "assistant", "content": "It is blue."}, - ] - messages3 = [ - {"role": "user", "content": "What color is the sky?"}, - {"role": "assistant", "content": "It is blue."}, - {"role": "user", "content": "And at night?"}, + {"role": "user", "content": "dummy"}, + {"role": "assistant", "content": "", "tool_calls": dummy_tool_calls}, + {"role": "tool", "name": "dummy", "content": "dummy"}, ] - text1 = tokenizer.apply_chat_template(messages1, tokenize=False, add_generation_prompt=True) - text2 = tokenizer.apply_chat_template(messages2, tokenize=False) - text3 = tokenizer.apply_chat_template(messages3, tokenize=False) + text1 = tokenizer.apply_chat_template(messages1, tokenize=False) + text2 = tokenizer.apply_chat_template(messages2, tokenize=False, add_generation_prompt=True) - return text2.startswith(text1) and text3.startswith(text2) + return text2.startswith(text1) # Modifications: @@ -749,33 +751,12 @@ def is_chat_template_prefix_preserving(tokenizer: PreTrainedTokenizer) -> bool: {%- endif %}""" -# Modifications: -# - {%- if '' in content %} -# + {%- if '' in content and '' in content %} -# Always check for both tags to avoid edge cases where the model generates only one tag, which would otherwise be parsed incorrectly -# - {{- '<|im_start|>' + message.role + '\n' + content }} -# + {{- '<|im_start|>' + message.role + '\n\n' + reasoning_content + '\n\n\n' + content }} -# Always include thinking block during training. It's important to have a prefix-preserving template. -def _patch_qwen3_5_training_template(template: str) -> str: - return template.replace( - "{%- if '' in content %}", - "{%- if '' in content and '' in content %}", - ).replace( - "{{- '<|im_start|>' + message.role + '\\n' + content }}", - "{{- '<|im_start|>' + message.role + '\\n\\n' + reasoning_content + '\\n\\n\\n' + content }}", - ) - - -qwen3_5_training_chat_template_2b_and_below = _patch_qwen3_5_training_template(qwen3_5_chat_template_2b_and_below) -qwen3_5_training_chat_template_4b_and_above = _patch_qwen3_5_training_template(qwen3_5_chat_template_4b_and_above) - - def get_training_chat_template(tokenizer: PreTrainedTokenizer) -> str | None: r""" Get a prefix-preserving chat template for training, if needed. - If the tokenizer's template isn't prefix-preserving, returns a training-compatible template (currently Qwen3 and - Qwen3.5 supported). Otherwise, returns `None`. + If the tokenizer's template isn't prefix-preserving, returns a training-compatible template (currently Qwen3 + supported). Otherwise, returns `None`. Args: tokenizer (`PreTrainedTokenizer`): @@ -793,27 +774,31 @@ def get_training_chat_template(tokenizer: PreTrainedTokenizer) -> str | None: >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B") >>> messages1 = [ - ... {"role": "user", "content": "What color is the sky?"}, - ... {"role": "assistant", "content": "It is blue."}, + ... {"role": "user", "content": "What is 2 * 3?"}, + ... { + ... "role": "assistant", + ... "content": "", + ... "tool_calls": [{"type": "function", "function": {"name": "multiply", "arguments": {"a": 2, "b": 3}}}], + ... }, ... ] - >>> messages2 = [ - ... {"role": "user", "content": "What color is the sky?"}, - ... {"role": "assistant", "content": "It is blue."}, - ... {"role": "user", "content": "And at night?"}, + >>> messages2 = messages1 + [ + ... {"role": "tool", "name": "multiply", "content": "6"}, ... ] >>> tokenizer.apply_chat_template(messages1, tokenize=False) - '<|im_start|>user\nWhat color is the sky?<|im_end|>\n<|im_start|>assistant\n\n\n\n\nIt is blue.<|im_end|>\n' + '<|im_start|>user\nWhat is 2 * 3?<|im_end|>\n<|im_start|>assistant\n\n\n\n\n\n{"name": "multiply", "arguments": {"a": 2, "b": 3}}\n<|im_end|>\n' - >>> tokenizer.apply_chat_template(messages2, tokenize=False) - '<|im_start|>user\nWhat color is the sky?<|im_end|>\n<|im_start|>assistant\nIt is blue.<|im_end|>\n<|im_start|>user\nAnd at night?<|im_end|>\n' + >>> tokenizer.apply_chat_template(messages2, tokenize=False, add_generation_prompt=True) + '<|im_start|>user\nWhat is 2 * 3?<|im_end|>\n<|im_start|>assistant\n\n{"name": "multiply", "arguments": {"a": 2, "b": 3}}\n<|im_end|>\n<|im_start|>user\n\n6\n<|im_end|>\n<|im_start|>assistant\n' - >>> # ^ think tags missing + >>> # ^ think tags missing >>> chat_template = get_training_chat_template(tokenizer) >>> tokenizer.apply_chat_template(messages1, tokenize=False, chat_template=chat_template) - '<|im_start|>user\nWhat color is the sky?<|im_end|>\n<|im_start|>assistant\n\n\n\n\nIt is blue.<|im_end|>\n' + '<|im_start|>user\nWhat is 2 * 3?<|im_end|>\n<|im_start|>assistant\n\n\n\n\n\n{"name": "multiply", "arguments": {"a": 2, "b": 3}}\n<|im_end|>\n' - >>> tokenizer.apply_chat_template(messages2, tokenize=False, chat_template=chat_template) - '<|im_start|>user\nWhat color is the sky?<|im_end|>\n<|im_start|>assistant\n\n\n\n\nIt is blue.<|im_end|>\n<|im_start|>user\nAnd at night?<|im_end|>\n' + >>> tokenizer.apply_chat_template( + ... messages2, tokenize=False, add_generation_prompt=True, chat_template=chat_template + ... ) + '<|im_start|>user\nWhat is 2 * 3?<|im_end|>\n<|im_start|>assistant\n\n\n\n\n\n{"name": "multiply", "arguments": {"a": 2, "b": 3}}\n<|im_end|>\n<|im_start|>user\n\n6\n<|im_end|>\n<|im_start|>assistant\n' ``` """ # First check if patching is needed @@ -822,10 +807,6 @@ def get_training_chat_template(tokenizer: PreTrainedTokenizer) -> str | None: if tokenizer.chat_template == qwen3_chat_template: return qwen3_training_chat_template - if tokenizer.chat_template == qwen3_5_chat_template_2b_and_below: - return qwen3_5_training_chat_template_2b_and_below - if tokenizer.chat_template == qwen3_5_chat_template_4b_and_above: - return qwen3_5_training_chat_template_4b_and_above else: raise ValueError( "The tokenizer's chat template is not prefix-preserving and patching is not supported for this template. " diff --git a/trl/experimental/async_grpo/async_rollout_worker.py b/trl/experimental/async_grpo/async_rollout_worker.py index 3d7350d5a7..df32627641 100644 --- a/trl/experimental/async_grpo/async_rollout_worker.py +++ b/trl/experimental/async_grpo/async_rollout_worker.py @@ -555,31 +555,41 @@ async def _generate_one( def _get_tool_suffix_ids(self, tool_messages: list[dict[str, Any]]) -> list[int]: """Get token IDs for tool result formatting by using a minimal dummy conversation.""" + dummy_tool_calls = [{"type": "function", "function": {"name": "dummy", "arguments": {}}}] dummy_messages = [ - {"role": "user", "content": ""}, - {"role": "assistant", "content": ""}, + {"role": "user", "content": "dummy"}, + { + "role": "assistant", + # "content" is required here because VLM processors crash on tokenize=True without it + # (KeyError in processing_utils.py). See huggingface/transformers#45290. + "content": "", + "tool_calls": dummy_tool_calls, + }, ] prefix_ids = self.tokenizer.apply_chat_template( dummy_messages, - return_dict=False, - tools=self.tools or None, # `or None`: Llama bug: it renders tool boilerplate for tools=[] + add_generation_prompt=False, + tokenize=True, chat_template=self.chat_template, + return_dict=False, **self.chat_template_kwargs, ) full_ids = self.tokenizer.apply_chat_template( dummy_messages + tool_messages, - return_dict=False, - chat_template=self.chat_template, add_generation_prompt=True, - tools=self.tools or None, # `or None`: Llama bug: it renders tool boilerplate for tools=[] + tokenize=True, + chat_template=self.chat_template, + return_dict=False, **self.chat_template_kwargs, ) # Some chat templates (notably Qwen3/Qwen3.5) render "...<|im_end|>\n" after an assistant/tool block. # When we compute `suffix_ids` by slicing `full_ids`, we must align the slicing boundary to - # EOS (not EOS + newline). - last_eos_idx = max(i for i, tok_id in enumerate(prefix_ids) if tok_id == self.tokenizer.eos_token_id) - prefix_ids = prefix_ids[: last_eos_idx + 1] + # EOS (not EOS + newline). Templates that don't use EOS as end-of-turn (e.g. Gemma uses + # ) skip this trimming. + eos_positions = [i for i, tok_id in enumerate(prefix_ids) if tok_id == self.tokenizer.eos_token_id] + if eos_positions: + prefix_ids = prefix_ids[: eos_positions[-1] + 1] if full_ids[: len(prefix_ids)] != prefix_ids: raise ValueError("Unexpected tokenization: the EOS-trimmed prefix IDs are not a prefix of the full IDs.") diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index fd276b671f..e949181e20 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -1404,7 +1404,17 @@ def _generate_single_turn(self, prompt_ids, images, multimodal_fields): def _get_tool_suffix_ids(self, tool_messages): """Get token IDs for tool result formatting by using a minimal dummy conversation.""" - dummy_messages = [{"role": "user", "content": "dummy"}, {"role": "assistant", "content": "dummy"}] + dummy_tool_calls = [{"type": "function", "function": {"name": "dummy", "arguments": {}}}] + dummy_messages = [ + {"role": "user", "content": "dummy"}, + { + "role": "assistant", + # "content" is required here because VLM processors crash on tokenize=True without it + # (KeyError in processing_utils.py). See huggingface/transformers#45290. + "content": "", + "tool_calls": dummy_tool_calls, + }, + ] if self._is_vlm: dummy_messages = prepare_multimodal_messages(dummy_messages) prefix_ids = self.processing_class.apply_chat_template(