-
Notifications
You must be signed in to change notification settings - Fork 2.6k
Narrow prefix-preserving check to the actual requirement #5458
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 12 commits
4b3aa51
0894910
730070b
4622d77
8a00354
cd8cbfc
8b35321
160d6a0
8dd341b
7e3ddd7
71bf73f
103d3c9
fffcb67
87131e5
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -146,13 +146,25 @@ def test_add_response_schema(self, tokenizer_name): | |||||
| class TestIsChatTemplatePrefixPreserving: | ||||||
| def test_prefix_preserving_template(self): | ||||||
| tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen3MoeForSequenceClassification") | ||||||
| # docstyle-ignore | ||||||
| tokenizer.chat_template = textwrap.dedent(r""" | ||||||
| {%- for message in messages %} | ||||||
|
|
||||||
| {%- if message.role == 'user' %} | ||||||
| {{- '<|im_start|>user\n' + message.content + '<|im_end|>\n' }} | ||||||
| {%- elif message.role == 'assistant' %} | ||||||
| {{- '<|im_start|>assistant\n' + message.content + '<|im_end|>\n' }} | ||||||
| {{- '<|im_start|>assistant\n' + message.content }} | ||||||
| {%- if message.tool_calls %} | ||||||
| {%- for tool_call in message.tool_calls %} | ||||||
| {%- if tool_call.function %} | ||||||
| {%- set tool_call = tool_call.function %} | ||||||
| {%- endif %} | ||||||
| {{- '<tool_call>' + tool_call.name + '</tool_call>' }} | ||||||
| {%- endfor %} | ||||||
| {%- endif %} | ||||||
| {{- '<|im_end|>\n' }} | ||||||
| {%- elif message.role == 'tool' %} | ||||||
| {{- '<|im_start|>tool\n' + message.content + '<|im_end|>\n' }} | ||||||
| {%- endif %} | ||||||
|
|
||||||
| {%- endfor %} | ||||||
|
|
@@ -164,8 +176,9 @@ def test_prefix_preserving_template(self): | |||||
|
|
||||||
| def test_non_prefix_preserving_template(self): | ||||||
| tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen3MoeForSequenceClassification") | ||||||
| # The following template is quite typical of models like Qwen3 and GPT-OSS, where the thinking part is | ||||||
| # only present for last assistant message, which makes it non-prefix-preserving. | ||||||
| # The following template is quite typical of models like Qwen3 and GPT-OSS, where the thinking part (even | ||||||
| # empty) is only present for last assistant message, which makes it non-prefix-preserving: appending a tool | ||||||
| # message changes the earlier output. | ||||||
| # docstyle-ignore | ||||||
| tokenizer.chat_template = textwrap.dedent(r""" | ||||||
| {%- if messages[0].role == 'system' %} | ||||||
|
|
@@ -202,7 +215,17 @@ def test_non_prefix_preserving_template(self): | |||||
| {%- else %} | ||||||
| {{- '<|im_start|>' + message.role + '\n' + content }} | ||||||
| {%- endif %} | ||||||
| {%- if message.tool_calls %} | ||||||
| {%- for tool_call in message.tool_calls %} | ||||||
| {%- if tool_call.function %} | ||||||
| {%- set tool_call = tool_call.function %} | ||||||
| {%- endif %} | ||||||
| {{- '<tool_call>' + tool_call.name + '</tool_call>' }} | ||||||
| {%- endfor %} | ||||||
| {%- endif %} | ||||||
| {{- '<|im_end|>\n' }} | ||||||
| {%- elif message.role == "tool" %} | ||||||
| {{- '<|im_start|>tool\n' + content + '<|im_end|>\n' }} | ||||||
qgallouedec marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
| {%- endif %} | ||||||
| {%- endfor %} | ||||||
| {%- if add_generation_prompt %} | ||||||
|
|
@@ -218,14 +241,6 @@ def test_non_prefix_preserving_template(self): | |||||
| "tokenizer_name", | ||||||
| [ | ||||||
| pytest.param("trl-internal-testing/tiny-Qwen3MoeForSequenceClassification", id="qwen3"), | ||||||
| pytest.param( | ||||||
| "trl-internal-testing/tiny-Qwen3_5ForConditionalGeneration", | ||||||
| id="qwen35", | ||||||
| marks=pytest.mark.skipif( | ||||||
| Version(transformers.__version__) < Version("5.0.0"), | ||||||
| reason="Qwen3.5 tokenizer requires transformers>=5.0.0", | ||||||
| ), | ||||||
| ), | ||||||
| ], | ||||||
| ) | ||||||
| class TestGetTrainingChatTemplate: | ||||||
|
|
@@ -475,7 +490,7 @@ def test_parse_response_tool_call_without_arguments(self, tokenizer_name): | |||||
| tool_calls = [{"type": "function", "function": {"name": "ping", "arguments": {}}}] | ||||||
| messages = [ | ||||||
| {"role": "user", "content": "Ping the service."}, | ||||||
| {"role": "assistant", "tool_calls": tool_calls}, | ||||||
| {"role": "assistant", "content": "", "tool_calls": tool_calls}, | ||||||
|
||||||
| {"role": "assistant", "content": "", "tool_calls": tool_calls}, | |
| {"role": "assistant", "tool_calls": tool_calls}, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
strange, I remember that it failed without the content, let's try again
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ah ok I get it, wrong line:
https://github.com/huggingface/transformers/blob/e1b80de84d3c5da35669b2834ef017eeaf620f93/src/transformers/processing_utils.py#L1807
without the content: "", you'd get
if tokenize:
batch_images, batch_videos = [], []
batch_audios = []
for conversation in conversations:
images, videos = [], []
for message in conversation:
> visuals = [content for content in message["content"] if content["type"] in ["image", "video"]]
^^^^^^^^^^^^^^^^^^
E KeyError: 'content'
../miniconda3/envs/trl/lib/python3.13/site-packages/transformers/processing_utils.py:1807: KeyError
--------------------------------------------------------------------------- Captured stderr call ---------------------------------------------------------------------------
The fast path is not available because one of the required library is not installed. Falling back to torch implementation. To install follow https://github.com/fla-org/flash-linear-attention#installation and https://github.com/Dao-AILab/causal-conv1d
Loading weights: 100%|██████████| 60/60 [00:00<00:00, 3877.51it/s]
Loading weights: 100%|██████████| 27/27 [00:00<00:00, 12929.13it/s]
The tokenizer has new PAD/BOS/EOS tokens that differ from the model config and generation config. The model config and generation config were aligned accordingly, being updated with the tokenizer's values. Updated tokens: {'eos_token_id': 248046, 'pad_token_id': 248044}.
0%| | 0/51 [00:00<?, ?it/s]Kwargs passed to `processor.__call__` have to be in `processor_kwargs` dict, not in `**kwargs`
========================================================================= short test summary info ==========================================================================
FAILED tests/test_grpo_trainer.py::TestGRPOTrainer::test_training_with_tools_multimodal_response - KeyError: 'content'
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
After reconsideration, I think it's likely a transformers bug.
huggingface/transformers#45290
I suggest keeping this patch for now so that it's not blocking for users
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -630,6 +630,10 @@ def is_chat_template_prefix_preserving(tokenizer: PreTrainedTokenizer) -> bool: | |
| """ | ||
| Check whether the chat template preserves prefixes when applied. | ||
|
|
||
| A prefix-preserving chat template renders earlier messages identically regardless of what messages follow. This | ||
| property is required by `_get_tool_suffix_ids`, which extracts tool response formatting tokens by comparing | ||
| tokenizations with and without tool messages appended. | ||
|
|
||
| Args: | ||
| tokenizer (`PreTrainedTokenizer`): | ||
| Tokenizer instance to check. | ||
|
|
@@ -638,24 +642,22 @@ def is_chat_template_prefix_preserving(tokenizer: PreTrainedTokenizer) -> bool: | |
| `bool`: | ||
| `True` if the chat template preserves prefixes, `False` otherwise. | ||
| """ | ||
| # Use the same dummy messages as _get_tool_suffix_ids to test the exact property it relies on. | ||
| dummy_tool_calls = [{"type": "function", "function": {"name": "dummy", "arguments": {}}}] | ||
| messages1 = [ | ||
| {"role": "user", "content": "What color is the sky?"}, | ||
| {"role": "user", "content": "dummy"}, | ||
| {"role": "assistant", "content": "", "tool_calls": dummy_tool_calls}, | ||
| ] | ||
| messages2 = [ | ||
| {"role": "user", "content": "What color is the sky?"}, | ||
| {"role": "assistant", "content": "It is blue."}, | ||
| ] | ||
| messages3 = [ | ||
| {"role": "user", "content": "What color is the sky?"}, | ||
| {"role": "assistant", "content": "It is blue."}, | ||
| {"role": "user", "content": "And at night?"}, | ||
| {"role": "user", "content": "dummy"}, | ||
| {"role": "assistant", "content": "", "tool_calls": dummy_tool_calls}, | ||
| {"role": "tool", "name": "dummy", "content": "dummy"}, | ||
| ] | ||
|
|
||
| text1 = tokenizer.apply_chat_template(messages1, tokenize=False, add_generation_prompt=True) | ||
| text2 = tokenizer.apply_chat_template(messages2, tokenize=False) | ||
| text3 = tokenizer.apply_chat_template(messages3, tokenize=False) | ||
| text1 = tokenizer.apply_chat_template(messages1, tokenize=False) | ||
| text2 = tokenizer.apply_chat_template(messages2, tokenize=False, add_generation_prompt=True) | ||
|
|
||
| return text2.startswith(text1) and text3.startswith(text2) | ||
| return text2.startswith(text1) | ||
|
|
||
|
|
||
| # Modifications: | ||
|
|
@@ -749,33 +751,12 @@ def is_chat_template_prefix_preserving(tokenizer: PreTrainedTokenizer) -> bool: | |
| {%- endif %}""" | ||
|
|
||
|
|
||
| # Modifications: | ||
| # - {%- if '</think>' in content %} | ||
| # + {%- if '<think>' in content and '</think>' in content %} | ||
| # Always check for both tags to avoid edge cases where the model generates only one tag, which would otherwise be parsed incorrectly | ||
| # - {{- '<|im_start|>' + message.role + '\n' + content }} | ||
| # + {{- '<|im_start|>' + message.role + '\n<think>\n' + reasoning_content + '\n</think>\n\n' + content }} | ||
| # Always include thinking block during training. It's important to have a prefix-preserving template. | ||
| def _patch_qwen3_5_training_template(template: str) -> str: | ||
| return template.replace( | ||
| "{%- if '</think>' in content %}", | ||
| "{%- if '<think>' in content and '</think>' in content %}", | ||
| ).replace( | ||
| "{{- '<|im_start|>' + message.role + '\\n' + content }}", | ||
| "{{- '<|im_start|>' + message.role + '\\n<think>\\n' + reasoning_content + '\\n</think>\\n\\n' + content }}", | ||
| ) | ||
|
|
||
|
|
||
| qwen3_5_training_chat_template_2b_and_below = _patch_qwen3_5_training_template(qwen3_5_chat_template_2b_and_below) | ||
| qwen3_5_training_chat_template_4b_and_above = _patch_qwen3_5_training_template(qwen3_5_chat_template_4b_and_above) | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You mention Qwen3.5 doesn't need patching anymore. Does this depend on the transformers version?
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. no, it's really about the chat template: from transformers import AutoTokenizer, AutoProcessor
dummy_tool_calls = [{"type": "function", "function": {"name": "dummy", "arguments": {}}}]
messages1 = [
{"role": "user", "content": "dummy"},
{"role": "assistant", "content": "", "tool_calls": dummy_tool_calls},
]
messages2 = messages1 + [
{"role": "tool", "name": "dummy", "content": "dummy"},
]
model_id = "Qwen/Qwen3-0.6B"
tokenizer = AutoTokenizer.from_pretrained(model_id)
text1 = tokenizer.apply_chat_template(messages1, tokenize=False)
text2 = tokenizer.apply_chat_template(messages2, tokenize=False, add_generation_prompt=True)
print(f"\n{'='*60}")
print(f"{model_id}")
print(f"Prefix-preserving: {text2.startswith(text1)}")
print(repr(text1))
print(repr(text2))
model_id = "Qwen/Qwen3.5-0.8B"
tokenizer = AutoProcessor.from_pretrained(model_id)
text1 = tokenizer.apply_chat_template(messages1, tokenize=False)
text2 = tokenizer.apply_chat_template(messages2, tokenize=False, add_generation_prompt=True)
print(f"\n{'='*60}")
print(f"{model_id}")
print(f"Prefix-preserving: {text2.startswith(text1)}")
print(repr(text1))
print(repr(text2)) |
||
|
|
||
|
|
||
| def get_training_chat_template(tokenizer: PreTrainedTokenizer) -> str | None: | ||
| r""" | ||
| Get a prefix-preserving chat template for training, if needed. | ||
|
|
||
| If the tokenizer's template isn't prefix-preserving, returns a training-compatible template (currently Qwen3 and | ||
| Qwen3.5 supported). Otherwise, returns `None`. | ||
| If the tokenizer's template isn't prefix-preserving, returns a training-compatible template (currently Qwen3 | ||
| supported). Otherwise, returns `None`. | ||
|
|
||
| Args: | ||
| tokenizer (`PreTrainedTokenizer`): | ||
|
|
@@ -793,27 +774,31 @@ def get_training_chat_template(tokenizer: PreTrainedTokenizer) -> str | None: | |
|
|
||
| >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B") | ||
| >>> messages1 = [ | ||
| ... {"role": "user", "content": "What color is the sky?"}, | ||
| ... {"role": "assistant", "content": "It is blue."}, | ||
| ... {"role": "user", "content": "What is 2 * 3?"}, | ||
| ... { | ||
| ... "role": "assistant", | ||
| ... "content": "", | ||
| ... "tool_calls": [{"type": "function", "function": {"name": "multiply", "arguments": {"a": 2, "b": 3}}}], | ||
| ... }, | ||
| ... ] | ||
| >>> messages2 = [ | ||
| ... {"role": "user", "content": "What color is the sky?"}, | ||
| ... {"role": "assistant", "content": "It is blue."}, | ||
| ... {"role": "user", "content": "And at night?"}, | ||
| >>> messages2 = messages1 + [ | ||
| ... {"role": "tool", "name": "multiply", "content": "6"}, | ||
| ... ] | ||
| >>> tokenizer.apply_chat_template(messages1, tokenize=False) | ||
| '<|im_start|>user\nWhat color is the sky?<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\nIt is blue.<|im_end|>\n' | ||
| '<|im_start|>user\nWhat is 2 * 3?<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n<tool_call>\n{"name": "multiply", "arguments": {"a": 2, "b": 3}}\n</tool_call><|im_end|>\n' | ||
|
|
||
| >>> tokenizer.apply_chat_template(messages2, tokenize=False) | ||
| '<|im_start|>user\nWhat color is the sky?<|im_end|>\n<|im_start|>assistant\nIt is blue.<|im_end|>\n<|im_start|>user\nAnd at night?<|im_end|>\n' | ||
| >>> tokenizer.apply_chat_template(messages2, tokenize=False, add_generation_prompt=True) | ||
| '<|im_start|>user\nWhat is 2 * 3?<|im_end|>\n<|im_start|>assistant\n<tool_call>\n{"name": "multiply", "arguments": {"a": 2, "b": 3}}\n</tool_call><|im_end|>\n<|im_start|>user\n<tool_response>\n6\n</tool_response><|im_end|>\n<|im_start|>assistant\n' | ||
|
|
||
| >>> # ^ think tags missing | ||
| >>> # ^ think tags missing | ||
| >>> chat_template = get_training_chat_template(tokenizer) | ||
| >>> tokenizer.apply_chat_template(messages1, tokenize=False, chat_template=chat_template) | ||
| '<|im_start|>user\nWhat color is the sky?<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\nIt is blue.<|im_end|>\n' | ||
| '<|im_start|>user\nWhat is 2 * 3?<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n<tool_call>\n{"name": "multiply", "arguments": {"a": 2, "b": 3}}\n</tool_call><|im_end|>\n' | ||
|
|
||
| >>> tokenizer.apply_chat_template(messages2, tokenize=False, chat_template=chat_template) | ||
| '<|im_start|>user\nWhat color is the sky?<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\nIt is blue.<|im_end|>\n<|im_start|>user\nAnd at night?<|im_end|>\n' | ||
| >>> tokenizer.apply_chat_template( | ||
| ... messages2, tokenize=False, add_generation_prompt=True, chat_template=chat_template | ||
| ... ) | ||
| '<|im_start|>user\nWhat is 2 * 3?<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n<tool_call>\n{"name": "multiply", "arguments": {"a": 2, "b": 3}}\n</tool_call><|im_end|>\n<|im_start|>user\n<tool_response>\n6\n</tool_response><|im_end|>\n<|im_start|>assistant\n' | ||
| ``` | ||
| """ | ||
| # First check if patching is needed | ||
|
|
@@ -822,10 +807,6 @@ def get_training_chat_template(tokenizer: PreTrainedTokenizer) -> str | None: | |
|
|
||
| if tokenizer.chat_template == qwen3_chat_template: | ||
| return qwen3_training_chat_template | ||
| if tokenizer.chat_template == qwen3_5_chat_template_2b_and_below: | ||
| return qwen3_5_training_chat_template_2b_and_below | ||
| if tokenizer.chat_template == qwen3_5_chat_template_4b_and_above: | ||
| return qwen3_5_training_chat_template_4b_and_above | ||
| else: | ||
| raise ValueError( | ||
| "The tokenizer's chat template is not prefix-preserving and patching is not supported for this template. " | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -555,31 +555,35 @@ async def _generate_one( | |
|
|
||
| def _get_tool_suffix_ids(self, tool_messages: list[dict[str, Any]]) -> list[int]: | ||
| """Get token IDs for tool result formatting by using a minimal dummy conversation.""" | ||
| dummy_tool_calls = [{"type": "function", "function": {"name": "dummy", "arguments": {}}}] | ||
| dummy_messages = [ | ||
| {"role": "user", "content": ""}, | ||
| {"role": "assistant", "content": ""}, | ||
| {"role": "user", "content": "dummy"}, | ||
| {"role": "assistant", "content": "", "tool_calls": dummy_tool_calls}, | ||
| ] | ||
| prefix_ids = self.tokenizer.apply_chat_template( | ||
| dummy_messages, | ||
| return_dict=False, | ||
| tools=self.tools or None, # `or None`: Llama bug: it renders tool boilerplate for tools=[] | ||
| add_generation_prompt=False, | ||
| tokenize=True, | ||
| chat_template=self.chat_template, | ||
| return_dict=False, | ||
| **self.chat_template_kwargs, | ||
| ) | ||
| full_ids = self.tokenizer.apply_chat_template( | ||
| dummy_messages + tool_messages, | ||
| return_dict=False, | ||
| chat_template=self.chat_template, | ||
| add_generation_prompt=True, | ||
| tools=self.tools or None, # `or None`: Llama bug: it renders tool boilerplate for tools=[] | ||
| tokenize=True, | ||
| chat_template=self.chat_template, | ||
| return_dict=False, | ||
| **self.chat_template_kwargs, | ||
| ) | ||
|
Comment on lines
556
to
584
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. just to align GRPO and async GRPO |
||
|
|
||
| # Some chat templates (notably Qwen3/Qwen3.5) render "...<|im_end|>\n" after an assistant/tool block. | ||
| # When we compute `suffix_ids` by slicing `full_ids`, we must align the slicing boundary to | ||
| # EOS (not EOS + newline). | ||
| last_eos_idx = max(i for i, tok_id in enumerate(prefix_ids) if tok_id == self.tokenizer.eos_token_id) | ||
| prefix_ids = prefix_ids[: last_eos_idx + 1] | ||
| # EOS (not EOS + newline). Templates that don't use EOS as end-of-turn (e.g. Gemma uses | ||
| # <turn|>) skip this trimming. | ||
| eos_positions = [i for i, tok_id in enumerate(prefix_ids) if tok_id == self.tokenizer.eos_token_id] | ||
| if eos_positions: | ||
| prefix_ids = prefix_ids[: eos_positions[-1] + 1] | ||
|
Comment on lines
-580
to
+592
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
|
|
||
| if full_ids[: len(prefix_ids)] != prefix_ids: | ||
| raise ValueError("Unexpected tokenization: the EOS-trimmed prefix IDs are not a prefix of the full IDs.") | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.