diff --git a/docs/source/grpo_trainer.md b/docs/source/grpo_trainer.md index 7a85ce772f..e65b16fc94 100644 --- a/docs/source/grpo_trainer.md +++ b/docs/source/grpo_trainer.md @@ -744,6 +744,7 @@ Tested with: - **GLM-4-MoE** ([4.5](https://huggingface.co/collections/zai-org/glm-45), [4.6](https://huggingface.co/collections/zai-org/glm-46) or [4.7](https://huggingface.co/collections/zai-org/glm-47)) — e.g., `zai-org/GLM-4.7` - [**GPT-OSS**](https://huggingface.co/collections/openai/gpt-oss) — e.g., `openai/gpt-oss-20b` - [**Qwen3**](https://huggingface.co/collections/Qwen/qwen3) — e.g., `Qwen/Qwen3-0.6B` +- [**Qwen3-VL**](https://huggingface.co/collections/Qwen/qwen3-vl) — e.g., `Qwen/Qwen3-VL-2B-Instruct` - [**Qwen3.5**](https://huggingface.co/collections/Qwen/qwen35) — e.g., `Qwen/Qwen3.5-2B` > [!TIP] diff --git a/tests/test_chat_template_utils.py b/tests/test_chat_template_utils.py index 22888bbba8..081f07478a 100644 --- a/tests/test_chat_template_utils.py +++ b/tests/test_chat_template_utils.py @@ -116,7 +116,8 @@ def test_clone_with_sequence_classification_model(self): [ pytest.param("trl-internal-testing/tiny-Glm4MoeForCausalLM", id="glm4moe"), pytest.param("trl-internal-testing/tiny-GptOssForCausalLM", id="gptoss"), - pytest.param("trl-internal-testing/tiny-Qwen3MoeForSequenceClassification", id="qwen3"), + pytest.param("trl-internal-testing/tiny-Qwen3MoeForCausalLM", id="qwen3"), + pytest.param("trl-internal-testing/tiny-Qwen3VLForConditionalGeneration", id="qwen3_vl"), pytest.param("trl-internal-testing/tiny-Qwen3_5ForConditionalGeneration", id="qwen35"), ], ) @@ -215,7 +216,7 @@ def test_deepseek_tool_calling(self, model_id): class TestIsChatTemplatePrefixPreserving: def test_prefix_preserving_template(self): - tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen3MoeForSequenceClassification") + tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen3MoeForCausalLM") # docstyle-ignore tokenizer.chat_template = 
textwrap.dedent(r""" {%- for message in messages %} @@ -245,7 +246,7 @@ def test_prefix_preserving_template(self): assert is_chat_template_prefix_preserving(tokenizer) is True def test_non_prefix_preserving_template(self): - tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen3MoeForSequenceClassification") + tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen3MoeForCausalLM") # The following template is quite typical of models like Qwen3 and GPT-OSS, where the thinking part (even # empty) is only present for last assistant message, which makes it non-prefix-preserving: appending a tool # message changes the earlier output. @@ -312,7 +313,7 @@ def test_non_prefix_preserving_template(self): [ pytest.param("trl-internal-testing/tiny-GptOssForCausalLM", id="gptoss"), pytest.param("trl-internal-testing/tiny-LlamaForCausalLM-3", id="llama3"), - pytest.param("trl-internal-testing/tiny-Qwen3MoeForSequenceClassification", id="qwen3"), + pytest.param("trl-internal-testing/tiny-Qwen3MoeForCausalLM", id="qwen3"), ], ) class TestGetTrainingChatTemplate: @@ -503,7 +504,8 @@ def test_assistant_masks_multi_turn(self, tokenizer_name): [ pytest.param("trl-internal-testing/tiny-Glm4MoeForCausalLM", id="glm4moe"), pytest.param("trl-internal-testing/tiny-GptOssForCausalLM", id="gptoss"), - pytest.param("trl-internal-testing/tiny-Qwen3MoeForSequenceClassification", id="qwen3"), + pytest.param("trl-internal-testing/tiny-Qwen3MoeForCausalLM", id="qwen3"), + pytest.param("trl-internal-testing/tiny-Qwen3VLForConditionalGeneration", id="qwen3_vl"), pytest.param("trl-internal-testing/tiny-Qwen3_5ForConditionalGeneration", id="qwen35"), pytest.param( "trl-internal-testing/tiny-Gemma4ForConditionalGeneration", @@ -540,8 +542,10 @@ def test_parse_response_with_reasoning_content(self, tokenizer_name): if tokenizer_name in ( "trl-internal-testing/tiny-Gemma4ForConditionalGeneration", "trl-internal-testing/tiny-GptOssForCausalLM", + 
"trl-internal-testing/tiny-Qwen3VLForConditionalGeneration", ): - pytest.skip("This model doesn't support inline reasoning_content.") + pytest.skip("This tokenizer doesn't support inline reasoning_content.") + tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) if getattr(tokenizer, "response_schema", None) is None: tokenizer = add_response_schema(tokenizer) @@ -629,7 +633,7 @@ def test_parse_response_multiple_tool_calls(self, tokenizer_name): assert parsed == messages[-1] def test_parse_response_malformed_tool_call(self, tokenizer_name): - if tokenizer_name != "trl-internal-testing/tiny-Qwen3MoeForSequenceClassification": + if tokenizer_name != "trl-internal-testing/tiny-Qwen3MoeForCausalLM": pytest.skip("For simplicity, we only test the malformed tool call case on one tokenizer.") tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) if getattr(tokenizer, "response_schema", None) is None: diff --git a/trl/chat_template_utils.py b/trl/chat_template_utils.py index 3bafca8c6d..8925ae05e9 100644 --- a/trl/chat_template_utils.py +++ b/trl/chat_template_utils.py @@ -273,6 +273,8 @@ def clone_chat_template( qwen3_chat_template = (_CHAT_TEMPLATES_DIR / "qwen3.jinja").read_text() +qwen3_vl_chat_template = (_CHAT_TEMPLATES_DIR / "qwen3_vl.jinja").read_text() + qwen3_5_chat_template_2b_and_below = (_CHAT_TEMPLATES_DIR / "qwen3_5_2b_and_below.jinja").read_text() qwen3_5_chat_template_4b_and_above = (_CHAT_TEMPLATES_DIR / "qwen3_5_4b_and_above.jinja").read_text() @@ -313,7 +315,7 @@ def add_response_schema(tokenizer: PreTrainedTokenizer) -> PreTrainedTokenizer: if tokenizer.chat_template == gptoss_chat_template: tokenizer.response_schema = gptoss_schema return tokenizer - if tokenizer.chat_template == qwen3_chat_template: + if tokenizer.chat_template in [qwen3_chat_template, qwen3_vl_chat_template]: tokenizer.response_schema = qwen3_schema return tokenizer if tokenizer.chat_template in [qwen3_5_chat_template_2b_and_below, qwen3_5_chat_template_4b_and_above]: diff 
--git a/trl/chat_templates/README.md b/trl/chat_templates/README.md index 4c624b5f35..e4f6f66ac0 100644 --- a/trl/chat_templates/README.md +++ b/trl/chat_templates/README.md @@ -25,6 +25,10 @@ Original GPT-OSS chat template. Original Qwen3 chat template. +### `qwen3_vl.jinja` + +Original Qwen3-VL chat template. Unlike text-only Qwen3, this template is already prefix-preserving (no conditional thinking blocks), so no training patch is needed. + ### `qwen3_5_2b_and_below.jinja` / `qwen3_5_4b_and_above.jinja` Original Qwen3.5 chat templates. diff --git a/trl/chat_templates/qwen3_vl.jinja b/trl/chat_templates/qwen3_vl.jinja new file mode 100644 index 0000000000..124386803f --- /dev/null +++ b/trl/chat_templates/qwen3_vl.jinja @@ -0,0 +1,120 @@ +{%- if tools %} + {{- '<|im_start|>system\n' }} + {%- if messages[0].role == 'system' %} + {%- if messages[0].content is string %} + {{- messages[0].content }} + {%- else %} + {%- for content in messages[0].content %} + {%- if 'text' in content %} + {{- content.text }} + {%- endif %} + {%- endfor %} + {%- endif %} + {{- '\n\n' }} + {%- endif %} + {{- "# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>" }} + {%- for tool in tools %} + {{- "\n" }} + {{- tool | tojson }} + {%- endfor %} + {{- "\n</tools>\n\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{\"name\": <function-name>, \"arguments\": <args-json-object>}\n</tool_call><|im_end|>\n" }} +{%- else %} + {%- if messages[0].role == 'system' %} + {{- '<|im_start|>system\n' }} + {%- if messages[0].content is string %} + {{- messages[0].content }} + {%- else %} + {%- for content in messages[0].content %} + {%- if 'text' in content %} + {{- content.text }} + {%- endif %} + {%- endfor %} + {%- endif %} + {{- '<|im_end|>\n' }} + {%- endif %} +{%- endif %} +{%- set image_count = namespace(value=0) %} +{%- set video_count = namespace(value=0) %} +{%- for message in messages %} + {%- if message.role
== "user" %} + {{- '<|im_start|>' + message.role + '\n' }} + {%- if message.content is string %} + {{- message.content }} + {%- else %} + {%- for content in message.content %} + {%- if content.type == 'image' or 'image' in content or 'image_url' in content %} + {%- set image_count.value = image_count.value + 1 %} + {%- if add_vision_id %}Picture {{ image_count.value }}: {% endif -%} + <|vision_start|><|image_pad|><|vision_end|> + {%- elif content.type == 'video' or 'video' in content %} + {%- set video_count.value = video_count.value + 1 %} + {%- if add_vision_id %}Video {{ video_count.value }}: {% endif -%} + <|vision_start|><|video_pad|><|vision_end|> + {%- elif 'text' in content %} + {{- content.text }} + {%- endif %} + {%- endfor %} + {%- endif %} + {{- '<|im_end|>\n' }} + {%- elif message.role == "assistant" %} + {{- '<|im_start|>' + message.role + '\n' }} + {%- if message.content is string %} + {{- message.content }} + {%- else %} + {%- for content_item in message.content %} + {%- if 'text' in content_item %} + {{- content_item.text }} + {%- endif %} + {%- endfor %} + {%- endif %} + {%- if message.tool_calls %} + {%- for tool_call in message.tool_calls %} + {%- if (loop.first and message.content) or (not loop.first) %} + {{- '\n' }} + {%- endif %} + {%- if tool_call.function %} + {%- set tool_call = tool_call.function %} + {%- endif %} + {{- '\n{"name": "' }} + {{- tool_call.name }} + {{- '", "arguments": ' }} + {%- if tool_call.arguments is string %} + {{- tool_call.arguments }} + {%- else %} + {{- tool_call.arguments | tojson }} + {%- endif %} + {{- '}\n' }} + {%- endfor %} + {%- endif %} + {{- '<|im_end|>\n' }} + {%- elif message.role == "tool" %} + {%- if loop.first or (messages[loop.index0 - 1].role != "tool") %} + {{- '<|im_start|>user' }} + {%- endif %} + {{- '\n\n' }} + {%- if message.content is string %} + {{- message.content }} + {%- else %} + {%- for content in message.content %} + {%- if content.type == 'image' or 'image' in content or 'image_url' 
in content %} + {%- set image_count.value = image_count.value + 1 %} + {%- if add_vision_id %}Picture {{ image_count.value }}: {% endif -%} + <|vision_start|><|image_pad|><|vision_end|> + {%- elif content.type == 'video' or 'video' in content %} + {%- set video_count.value = video_count.value + 1 %} + {%- if add_vision_id %}Video {{ video_count.value }}: {% endif -%} + <|vision_start|><|video_pad|><|vision_end|> + {%- elif 'text' in content %} + {{- content.text }} + {%- endif %} + {%- endfor %} + {%- endif %} + {{- '\n</tool_response>' }} + {%- if loop.last or (messages[loop.index0 + 1].role != "tool") %} + {{- '<|im_end|>\n' }} + {%- endif %} + {%- endif %} +{%- endfor %} +{%- if add_generation_prompt %} + {{- '<|im_start|>assistant\n' }} +{%- endif %}