Skip to content
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 27 additions & 12 deletions tests/test_chat_template_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,13 +146,25 @@ def test_add_response_schema(self, tokenizer_name):
class TestIsChatTemplatePrefixPreserving:
def test_prefix_preserving_template(self):
tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen3MoeForSequenceClassification")
# docstyle-ignore
tokenizer.chat_template = textwrap.dedent(r"""
{%- for message in messages %}

{%- if message.role == 'user' %}
{{- '<|im_start|>user\n' + message.content + '<|im_end|>\n' }}
{%- elif message.role == 'assistant' %}
{{- '<|im_start|>assistant\n' + message.content + '<|im_end|>\n' }}
{{- '<|im_start|>assistant\n' + message.content }}
{%- if message.tool_calls %}
{%- for tool_call in message.tool_calls %}
{%- if tool_call.function %}
{%- set tool_call = tool_call.function %}
{%- endif %}
{{- '<tool_call>' + tool_call.name + '</tool_call>' }}
{%- endfor %}
{%- endif %}
{{- '<|im_end|>\n' }}
{%- elif message.role == 'tool' %}
{{- '<|im_start|>tool\n' + message.content + '<|im_end|>\n' }}
{%- endif %}

{%- endfor %}
Expand All @@ -164,8 +176,9 @@ def test_prefix_preserving_template(self):

def test_non_prefix_preserving_template(self):
tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen3MoeForSequenceClassification")
# The following template is quite typical of models like Qwen3 and GPT-OSS, where the thinking part is
# only present for last assistant message, which makes it non-prefix-preserving.
# The following template is quite typical of models like Qwen3 and GPT-OSS, where the thinking part (even
# if empty) is only rendered for the last assistant message. This makes it non-prefix-preserving: appending a
# tool message changes how the earlier assistant message is rendered.
# docstyle-ignore
tokenizer.chat_template = textwrap.dedent(r"""
{%- if messages[0].role == 'system' %}
Expand Down Expand Up @@ -202,7 +215,17 @@ def test_non_prefix_preserving_template(self):
{%- else %}
{{- '<|im_start|>' + message.role + '\n' + content }}
{%- endif %}
{%- if message.tool_calls %}
{%- for tool_call in message.tool_calls %}
{%- if tool_call.function %}
{%- set tool_call = tool_call.function %}
{%- endif %}
{{- '<tool_call>' + tool_call.name + '</tool_call>' }}
{%- endfor %}
{%- endif %}
{{- '<|im_end|>\n' }}
{%- elif message.role == "tool" %}
{{- '<|im_start|>tool\n' + content + '<|im_end|>\n' }}
{%- endif %}
{%- endfor %}
{%- if add_generation_prompt %}
Expand All @@ -218,14 +241,6 @@ def test_non_prefix_preserving_template(self):
"tokenizer_name",
[
pytest.param("trl-internal-testing/tiny-Qwen3MoeForSequenceClassification", id="qwen3"),
pytest.param(
"trl-internal-testing/tiny-Qwen3_5ForConditionalGeneration",
id="qwen35",
marks=pytest.mark.skipif(
Version(transformers.__version__) < Version("5.0.0"),
reason="Qwen3.5 tokenizer requires transformers>=5.0.0",
),
),
],
)
class TestGetTrainingChatTemplate:
Expand Down Expand Up @@ -475,7 +490,7 @@ def test_parse_response_tool_call_without_arguments(self, tokenizer_name):
tool_calls = [{"type": "function", "function": {"name": "ping", "arguments": {}}}]
messages = [
{"role": "user", "content": "Ping the service."},
{"role": "assistant", "tool_calls": tool_calls},
{"role": "assistant", "content": "", "tool_calls": tool_calls},
Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are you sure about this? A few lines above, we have: https://github.com/huggingface/transformers/blob/e1b80de84d3c5da35669b2834ef017eeaf620f93/src/transformers/processing_utils.py#L1789-L1790

                if not isinstance(message.get("content"), list):
                    continue
Suggested change
{"role": "assistant", "content": "", "tool_calls": tool_calls},
{"role": "assistant", "tool_calls": tool_calls},

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Strange, I remember that it failed without the content key. Let's try again.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ah ok I get it, wrong line:
https://github.com/huggingface/transformers/blob/e1b80de84d3c5da35669b2834ef017eeaf620f93/src/transformers/processing_utils.py#L1807

Without `"content": ""` in the assistant message, you'd get:

        if tokenize:
            batch_images, batch_videos = [], []
            batch_audios = []
            for conversation in conversations:
                images, videos = [], []
                for message in conversation:
>                   visuals = [content for content in message["content"] if content["type"] in ["image", "video"]]
                                                      ^^^^^^^^^^^^^^^^^^
E                   KeyError: 'content'

../miniconda3/envs/trl/lib/python3.13/site-packages/transformers/processing_utils.py:1807: KeyError
--------------------------------------------------------------------------- Captured stderr call ---------------------------------------------------------------------------
The fast path is not available because one of the required library is not installed. Falling back to torch implementation. To install follow https://github.com/fla-org/flash-linear-attention#installation and https://github.com/Dao-AILab/causal-conv1d
Loading weights: 100%|██████████| 60/60 [00:00<00:00, 3877.51it/s]
Loading weights: 100%|██████████| 27/27 [00:00<00:00, 12929.13it/s]
The tokenizer has new PAD/BOS/EOS tokens that differ from the model config and generation config. The model config and generation config were aligned accordingly, being updated with the tokenizer's values. Updated tokens: {'eos_token_id': 248046, 'pad_token_id': 248044}.
  0%|          | 0/51 [00:00<?, ?it/s]Kwargs passed to `processor.__call__` have to be in `processor_kwargs` dict, not in `**kwargs`
========================================================================= short test summary info ==========================================================================
FAILED tests/test_grpo_trainer.py::TestGRPOTrainer::test_training_with_tools_multimodal_response - KeyError: 'content'

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

After reconsideration, I think it's likely a transformers bug.

huggingface/transformers#45290

I suggest keeping this patch for now so that users aren't blocked in the meantime.

]
prefix = tokenizer.apply_chat_template(messages[:1], add_generation_prompt=True).input_ids
text = tokenizer.apply_chat_template(messages).input_ids
Expand Down
85 changes: 33 additions & 52 deletions trl/chat_template_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -630,6 +630,10 @@ def is_chat_template_prefix_preserving(tokenizer: PreTrainedTokenizer) -> bool:
"""
Check whether the chat template preserves prefixes when applied.

A prefix-preserving chat template renders earlier messages identically regardless of what messages follow. This
property is required by `_get_tool_suffix_ids`, which extracts tool response formatting tokens by comparing
tokenizations with and without tool messages appended.

Args:
tokenizer (`PreTrainedTokenizer`):
Tokenizer instance to check.
Expand All @@ -638,24 +642,22 @@ def is_chat_template_prefix_preserving(tokenizer: PreTrainedTokenizer) -> bool:
`bool`:
`True` if the chat template preserves prefixes, `False` otherwise.
"""
# Use the same dummy messages as _get_tool_suffix_ids to test the exact property it relies on.
dummy_tool_calls = [{"type": "function", "function": {"name": "dummy", "arguments": {}}}]
messages1 = [
{"role": "user", "content": "What color is the sky?"},
{"role": "user", "content": "dummy"},
{"role": "assistant", "content": "", "tool_calls": dummy_tool_calls},
]
messages2 = [
{"role": "user", "content": "What color is the sky?"},
{"role": "assistant", "content": "It is blue."},
]
messages3 = [
{"role": "user", "content": "What color is the sky?"},
{"role": "assistant", "content": "It is blue."},
{"role": "user", "content": "And at night?"},
{"role": "user", "content": "dummy"},
{"role": "assistant", "content": "", "tool_calls": dummy_tool_calls},
{"role": "tool", "name": "dummy", "content": "dummy"},
]

text1 = tokenizer.apply_chat_template(messages1, tokenize=False, add_generation_prompt=True)
text2 = tokenizer.apply_chat_template(messages2, tokenize=False)
text3 = tokenizer.apply_chat_template(messages3, tokenize=False)
text1 = tokenizer.apply_chat_template(messages1, tokenize=False)
text2 = tokenizer.apply_chat_template(messages2, tokenize=False, add_generation_prompt=True)

return text2.startswith(text1) and text3.startswith(text2)
return text2.startswith(text1)


# Modifications:
Expand Down Expand Up @@ -749,33 +751,12 @@ def is_chat_template_prefix_preserving(tokenizer: PreTrainedTokenizer) -> bool:
{%- endif %}"""


# Modifications:
# - {%- if '</think>' in content %}
# + {%- if '<think>' in content and '</think>' in content %}
# Always check for both tags to avoid edge cases where the model generates only one tag, which would otherwise be parsed incorrectly
# - {{- '<|im_start|>' + message.role + '\n' + content }}
# + {{- '<|im_start|>' + message.role + '\n<think>\n' + reasoning_content + '\n</think>\n\n' + content }}
# Always include thinking block during training. It's important to have a prefix-preserving template.
def _patch_qwen3_5_training_template(template: str) -> str:
return template.replace(
"{%- if '</think>' in content %}",
"{%- if '<think>' in content and '</think>' in content %}",
).replace(
"{{- '<|im_start|>' + message.role + '\\n' + content }}",
"{{- '<|im_start|>' + message.role + '\\n<think>\\n' + reasoning_content + '\\n</think>\\n\\n' + content }}",
)


qwen3_5_training_chat_template_2b_and_below = _patch_qwen3_5_training_template(qwen3_5_chat_template_2b_and_below)
qwen3_5_training_chat_template_4b_and_above = _patch_qwen3_5_training_template(qwen3_5_chat_template_4b_and_above)
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You mention Qwen3.5 doesn't need patching anymore. Does this depend on the transformers version?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

no, it's really about the chat template:

from transformers import AutoTokenizer, AutoProcessor

dummy_tool_calls = [{"type": "function", "function": {"name": "dummy", "arguments": {}}}]

messages1 = [
    {"role": "user", "content": "dummy"},
    {"role": "assistant", "content": "", "tool_calls": dummy_tool_calls},
]
messages2 = messages1 + [
    {"role": "tool", "name": "dummy", "content": "dummy"},
]


model_id = "Qwen/Qwen3-0.6B"
tokenizer = AutoTokenizer.from_pretrained(model_id)
text1 = tokenizer.apply_chat_template(messages1, tokenize=False)
text2 = tokenizer.apply_chat_template(messages2, tokenize=False, add_generation_prompt=True)
print(f"\n{'='*60}")
print(f"{model_id}")
print(f"Prefix-preserving: {text2.startswith(text1)}")
print(repr(text1))
print(repr(text2))

model_id = "Qwen/Qwen3.5-0.8B"
tokenizer = AutoProcessor.from_pretrained(model_id)
text1 = tokenizer.apply_chat_template(messages1, tokenize=False)
text2 = tokenizer.apply_chat_template(messages2, tokenize=False, add_generation_prompt=True)
print(f"\n{'='*60}")
print(f"{model_id}")
print(f"Prefix-preserving: {text2.startswith(text1)}")
print(repr(text1))
print(repr(text2))
============================================================
Qwen/Qwen3-0.6B
Prefix-preserving: False
'<|im_start|>user\ndummy<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n<tool_call>\n{"name": "dummy", "arguments": {}}\n</tool_call><|im_end|>\n'
'<|im_start|>user\ndummy<|im_end|>\n<|im_start|>assistant\n<tool_call>\n{"name": "dummy", "arguments": {}}\n</tool_call><|im_end|>\n<|im_start|>user\n<tool_response>\ndummy\n</tool_response><|im_end|>\n<|im_start|>assistant\n'

============================================================
Qwen/Qwen3.5-0.8B
Prefix-preserving: True
'<|im_start|>user\ndummy<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n<tool_call>\n<function=dummy>\n</function>\n</tool_call><|im_end|>\n'
'<|im_start|>user\ndummy<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n<tool_call>\n<function=dummy>\n</function>\n</tool_call><|im_end|>\n<|im_start|>user\n<tool_response>\ndummy\n</tool_response><|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n'



def get_training_chat_template(tokenizer: PreTrainedTokenizer) -> str | None:
r"""
Get a prefix-preserving chat template for training, if needed.

If the tokenizer's template isn't prefix-preserving, returns a training-compatible template (currently Qwen3 and
Qwen3.5 supported). Otherwise, returns `None`.
If the tokenizer's template isn't prefix-preserving, returns a training-compatible template (currently Qwen3
supported). Otherwise, returns `None`.

Args:
tokenizer (`PreTrainedTokenizer`):
Expand All @@ -793,27 +774,31 @@ def get_training_chat_template(tokenizer: PreTrainedTokenizer) -> str | None:

>>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")
>>> messages1 = [
... {"role": "user", "content": "What color is the sky?"},
... {"role": "assistant", "content": "It is blue."},
... {"role": "user", "content": "What is 2 * 3?"},
... {
... "role": "assistant",
... "content": "",
... "tool_calls": [{"type": "function", "function": {"name": "multiply", "arguments": {"a": 2, "b": 3}}}],
... },
... ]
>>> messages2 = [
... {"role": "user", "content": "What color is the sky?"},
... {"role": "assistant", "content": "It is blue."},
... {"role": "user", "content": "And at night?"},
>>> messages2 = messages1 + [
... {"role": "tool", "name": "multiply", "content": "6"},
... ]
>>> tokenizer.apply_chat_template(messages1, tokenize=False)
'<|im_start|>user\nWhat color is the sky?<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\nIt is blue.<|im_end|>\n'
'<|im_start|>user\nWhat is 2 * 3?<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n<tool_call>\n{"name": "multiply", "arguments": {"a": 2, "b": 3}}\n</tool_call><|im_end|>\n'

>>> tokenizer.apply_chat_template(messages2, tokenize=False)
'<|im_start|>user\nWhat color is the sky?<|im_end|>\n<|im_start|>assistant\nIt is blue.<|im_end|>\n<|im_start|>user\nAnd at night?<|im_end|>\n'
>>> tokenizer.apply_chat_template(messages2, tokenize=False, add_generation_prompt=True)
'<|im_start|>user\nWhat is 2 * 3?<|im_end|>\n<|im_start|>assistant\n<tool_call>\n{"name": "multiply", "arguments": {"a": 2, "b": 3}}\n</tool_call><|im_end|>\n<|im_start|>user\n<tool_response>\n6\n</tool_response><|im_end|>\n<|im_start|>assistant\n'

>>> # ^ think tags missing
>>> # ^ think tags missing
>>> chat_template = get_training_chat_template(tokenizer)
>>> tokenizer.apply_chat_template(messages1, tokenize=False, chat_template=chat_template)
'<|im_start|>user\nWhat color is the sky?<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\nIt is blue.<|im_end|>\n'
'<|im_start|>user\nWhat is 2 * 3?<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n<tool_call>\n{"name": "multiply", "arguments": {"a": 2, "b": 3}}\n</tool_call><|im_end|>\n'

>>> tokenizer.apply_chat_template(messages2, tokenize=False, chat_template=chat_template)
'<|im_start|>user\nWhat color is the sky?<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\nIt is blue.<|im_end|>\n<|im_start|>user\nAnd at night?<|im_end|>\n'
>>> tokenizer.apply_chat_template(
... messages2, tokenize=False, add_generation_prompt=True, chat_template=chat_template
... )
'<|im_start|>user\nWhat is 2 * 3?<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n<tool_call>\n{"name": "multiply", "arguments": {"a": 2, "b": 3}}\n</tool_call><|im_end|>\n<|im_start|>user\n<tool_response>\n6\n</tool_response><|im_end|>\n<|im_start|>assistant\n'
```
"""
# First check if patching is needed
Expand All @@ -822,10 +807,6 @@ def get_training_chat_template(tokenizer: PreTrainedTokenizer) -> str | None:

if tokenizer.chat_template == qwen3_chat_template:
return qwen3_training_chat_template
if tokenizer.chat_template == qwen3_5_chat_template_2b_and_below:
return qwen3_5_training_chat_template_2b_and_below
if tokenizer.chat_template == qwen3_5_chat_template_4b_and_above:
return qwen3_5_training_chat_template_4b_and_above
else:
raise ValueError(
"The tokenizer's chat template is not prefix-preserving and patching is not supported for this template. "
Expand Down
24 changes: 14 additions & 10 deletions trl/experimental/async_grpo/async_rollout_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -555,31 +555,35 @@ async def _generate_one(

def _get_tool_suffix_ids(self, tool_messages: list[dict[str, Any]]) -> list[int]:
"""Get token IDs for tool result formatting by using a minimal dummy conversation."""
dummy_tool_calls = [{"type": "function", "function": {"name": "dummy", "arguments": {}}}]
dummy_messages = [
{"role": "user", "content": ""},
{"role": "assistant", "content": ""},
{"role": "user", "content": "dummy"},
{"role": "assistant", "content": "", "tool_calls": dummy_tool_calls},
]
prefix_ids = self.tokenizer.apply_chat_template(
dummy_messages,
return_dict=False,
tools=self.tools or None, # `or None`: Llama bug: it renders tool boilerplate for tools=[]
add_generation_prompt=False,
tokenize=True,
chat_template=self.chat_template,
return_dict=False,
**self.chat_template_kwargs,
)
full_ids = self.tokenizer.apply_chat_template(
dummy_messages + tool_messages,
return_dict=False,
chat_template=self.chat_template,
add_generation_prompt=True,
tools=self.tools or None, # `or None`: Llama bug: it renders tool boilerplate for tools=[]
tokenize=True,
chat_template=self.chat_template,
return_dict=False,
**self.chat_template_kwargs,
)
Comment on lines 556 to 584
Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

just to align GRPO and async GRPO


# Some chat templates (notably Qwen3/Qwen3.5) render "...<|im_end|>\n" after an assistant/tool block.
# When we compute `suffix_ids` by slicing `full_ids`, we must align the slicing boundary to
# EOS (not EOS + newline).
last_eos_idx = max(i for i, tok_id in enumerate(prefix_ids) if tok_id == self.tokenizer.eos_token_id)
prefix_ids = prefix_ids[: last_eos_idx + 1]
# EOS (not EOS + newline). Templates that don't use EOS as end-of-turn (e.g. Gemma uses
# <turn|>) skip this trimming.
eos_positions = [i for i, tok_id in enumerate(prefix_ids) if tok_id == self.tokenizer.eos_token_id]
if eos_positions:
prefix_ids = prefix_ids[: eos_positions[-1] + 1]
Comment on lines -580 to +592
Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.


if full_ids[: len(prefix_ids)] != prefix_ids:
raise ValueError("Unexpected tokenization: the EOS-trimmed prefix IDs are not a prefix of the full IDs.")
Expand Down
6 changes: 5 additions & 1 deletion trl/trainer/grpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1404,7 +1404,11 @@ def _generate_single_turn(self, prompt_ids, images, multimodal_fields):

def _get_tool_suffix_ids(self, tool_messages):
"""Get token IDs for tool result formatting by using a minimal dummy conversation."""
dummy_messages = [{"role": "user", "content": "dummy"}, {"role": "assistant", "content": "dummy"}]
dummy_tool_calls = [{"type": "function", "function": {"name": "dummy", "arguments": {}}}]
dummy_messages = [
{"role": "user", "content": "dummy"},
{"role": "assistant", "content": "", "tool_calls": dummy_tool_calls},
]
if self._is_vlm:
dummy_messages = prepare_multimodal_messages(dummy_messages)
prefix_ids = self.processing_class.apply_chat_template(
Expand Down
Loading