diff --git a/examples/tool_chat_template_gemma4.jinja b/examples/tool_chat_template_gemma4.jinja new file mode 100644 index 000000000000..15c5238ac332 --- /dev/null +++ b/examples/tool_chat_template_gemma4.jinja @@ -0,0 +1,331 @@ +{%- macro format_parameters(properties, required) -%} + {%- set standard_keys = ['description', 'type', 'properties', 'required', 'nullable'] -%} + {%- set ns = namespace(found_first=false) -%} + {%- for key, value in properties | dictsort -%} + {%- set add_comma = false -%} + {%- if key not in standard_keys -%} + {%- if ns.found_first %},{% endif -%} + {%- set ns.found_first = true -%} + {{ key }}:{ + {%- if value['description'] -%} + description:<|"|>{{ value['description'] }}<|"|> + {%- set add_comma = true -%} + {%- endif -%} + {%- if value['nullable'] %} + {%- if add_comma %},{%- else -%} {%- set add_comma = true -%} {% endif -%} + nullable:true + {%- endif -%} + {%- if value['type'] | upper == 'STRING' -%} + {%- if value['enum'] -%} + {%- if add_comma %},{%- else -%} {%- set add_comma = true -%} {% endif -%} + enum:{{ format_argument(value['enum']) }} + {%- endif -%} + {%- elif value['type'] | upper == 'OBJECT' -%} + ,properties:{ + {%- if value['properties'] is defined and value['properties'] is mapping -%} + {{- format_parameters(value['properties'], value['required'] | default([])) -}} + {%- elif value is mapping -%} + {{- format_parameters(value, value['required'] | default([])) -}} + {%- endif -%} + } + {%- if value['required'] -%} + ,required:[ + {%- for item in value['required'] | default([]) -%} + <|"|>{{- item -}}<|"|> + {%- if not loop.last %},{% endif -%} + {%- endfor -%} + ] + {%- endif -%} + {%- elif value['type'] | upper == 'ARRAY' -%} + {%- if value['items'] is mapping and value['items'] -%} + ,items:{ + {%- set ns_items = namespace(found_first=false) -%} + {%- for item_key, item_value in value['items'] | dictsort -%} + {%- if item_value is not none -%} + {%- if ns_items.found_first %},{% endif -%} + {%- set ns_items.found_first = true -%} + {%- if item_key == 'properties' -%} + properties:{ + {%- if item_value is mapping -%} + {{- format_parameters(item_value, value['items']['required'] | default([])) -}} + {%- endif -%} + } + {%- elif item_key == 'required' -%} + required:[ + {%- for req_item in item_value -%} + <|"|>{{- req_item -}}<|"|> + {%- if not loop.last %},{% endif -%} + {%- endfor -%} + ] + {%- elif item_key == 'type' -%} + {%- if item_value is string -%} + type:{{ format_argument(item_value | upper) }} + {%- else -%} + type:{{ format_argument(item_value | map('upper') | list) }} + {%- endif -%} + {%- else -%} + {{ item_key }}:{{ format_argument(item_value) }} + {%- endif -%} + {%- endif -%} + {%- endfor -%} + } + {%- endif -%} + {%- endif -%} + {%- if add_comma %},{%- else -%} {%- set add_comma = true -%} {% endif -%} + type:<|"|>{{ value['type'] | upper }}<|"|>} + {%- endif -%} + {%- endfor -%} +{%- endmacro -%} +{%- macro format_function_declaration(tool_data) -%} + declaration:{{- tool_data['function']['name'] -}}{description:<|"|>{{- tool_data['function']['description'] -}}<|"|> + {%- set params = tool_data['function']['parameters'] -%} + {%- if params -%} + ,parameters:{ + {%- if params['properties'] -%} + properties:{ {{- format_parameters(params['properties'], params['required']) -}} }, + {%- endif -%} + {%- if params['required'] -%} + required:[ + {%- for item in params['required'] -%} + <|"|>{{- item -}}<|"|> + {{- ',' if not loop.last -}} + {%- endfor -%} + ], + {%- endif -%} + {%- if params['type'] -%} + type:<|"|>{{- params['type'] | upper -}}<|"|>} + {%- endif -%} + {%- endif -%} + {%- if 'response' in tool_data['function'] -%} + {%- set response_declaration = tool_data['function']['response'] -%} + ,response:{ + {%- if response_declaration['description'] -%} + description:<|"|>{{- response_declaration['description'] -}}<|"|>, + {%- endif -%} + {%- if response_declaration['type'] | upper == 'OBJECT' -%} + type:<|"|>{{- response_declaration['type'] | upper -}}<|"|>} + {%- endif -%} + {%- endif -%} + } +{%- endmacro -%} +{%- macro format_argument(argument, escape_keys=True) -%} + {%- if argument is string -%} + {{- '<|"|>' + argument + '<|"|>' -}} + {%- elif argument is boolean -%} + {{- 'true' if argument else 'false' -}} + {%- elif argument is mapping -%} + {{- '{' -}} + {%- set ns = namespace(found_first=false) -%} + {%- for key, value in argument | dictsort -%} + {%- if ns.found_first %},{% endif -%} + {%- set ns.found_first = true -%} + {%- if escape_keys -%} + {{- '<|"|>' + key + '<|"|>' -}} + {%- else -%} + {{- key -}} + {%- endif -%} + :{{- format_argument(value, escape_keys=escape_keys) -}} + {%- endfor -%} + {{- '}' -}} + {%- elif argument is sequence -%} + {{- '[' -}} + {%- for item in argument -%} + {{- format_argument(item, escape_keys=escape_keys) -}} + {%- if not loop.last %},{% endif -%} + {%- endfor -%} + {{- ']' -}} + {%- else -%} + {{- argument -}} + {%- endif -%} +{%- endmacro -%} +{%- macro strip_thinking(text) -%} + {%- set ns = namespace(result='') -%} + {%- for part in text.split('') -%} + {%- if '<|channel>' in part -%} + {%- set ns.result = ns.result + part.split('<|channel>')[0] -%} + {%- else -%} + {%- set ns.result = ns.result + part -%} + {%- endif -%} + {%- endfor -%} + {{- ns.result | trim -}} +{%- endmacro -%} + +{%- macro format_tool_response_block(tool_name, response) -%} + {{- '<|tool_response>' -}} + {%- if response is mapping -%} + {{- 'response:' + tool_name + '{' -}} + {%- for key, value in response | dictsort -%} + {{- key -}}:{{- format_argument(value, escape_keys=False) -}} + {%- if not loop.last %},{% endif -%} + {%- endfor -%} + {{- '}' -}} + {%- else -%} + {{- 'response:' + tool_name + '{value:' + format_argument(response, escape_keys=False) + '}' -}} + {%- endif -%} + {{- '' -}} +{%- endmacro -%} + +{%- set ns = namespace(prev_message_type=None) -%} +{%- set loop_messages = messages -%} +{{ bos_token }} +{%- if (enable_thinking is defined and enable_thinking) or tools or messages[0]['role'] in ['system', 'developer'] -%} + {{- '<|turn>system\n' -}} + + {%- if enable_thinking is defined and enable_thinking -%} + {{- '<|think|>' -}} + {%- set ns.prev_message_type = 'think' -%} + {%- endif -%} + + {%- if messages[0]['role'] in ['system', 'developer'] -%} + {{- messages[0]['content'] | trim -}} + {%- set loop_messages = messages[1:] -%} + {%- endif -%} + + {%- if tools -%} + {%- for tool in tools %} + {{- '<|tool>' -}} + {{- format_function_declaration(tool) | trim -}} + {{- '' -}} + {%- endfor %} + {%- set ns.prev_message_type = 'tool' -%} + {%- endif -%} + + {{- '\n' -}} +{%- endif %} + +{%- set ns_turn = namespace(last_user_idx=-1) -%} +{%- for i in range(loop_messages | length) -%} + {%- if loop_messages[i]['role'] == 'user' -%} + {%- set ns_turn.last_user_idx = i -%} + {%- endif -%} +{%- endfor -%} + +{%- for message in loop_messages -%} + {%- if message['role'] != 'tool' -%} + {%- set ns.prev_message_type = None -%} + {%- set role = 'model' if message['role'] == 'assistant' else message['role'] -%} + {#- OpenAI may emit multiple assistant messages in one tool loop (user → asst → tool → asst → tool). + Only the first of those should open <|turn>model; later ones continue the same model turn. -#} + {%- set prev_nt = namespace(role=None, found=false) -%} + {%- if loop.index0 > 0 -%} + {%- for j in range(loop.index0 - 1, -1, -1) -%} + {%- if not prev_nt.found -%} + {%- if loop_messages[j]['role'] != 'tool' -%} + {%- set prev_nt.role = loop_messages[j]['role'] -%} + {%- set prev_nt.found = true -%} + {%- endif -%} + {%- endif -%} + {%- endfor -%} + {%- endif -%} + {%- set continue_same_model_turn = (role == 'model' and prev_nt.role == 'assistant') -%} + {%- if not continue_same_model_turn -%} + {{- '<|turn>' + role + '\n' }} + {%- endif -%} + + {%- if message.get('reasoning') and loop.index0 > ns_turn.last_user_idx and message.get('tool_calls') -%} + {{- '<|channel>thought\n' + message['reasoning'] + '\n'}} + {%- endif -%} + + {%- if message['tool_calls'] -%} + {%- for tool_call in message['tool_calls'] -%} + {%- set function = tool_call['function'] -%} + {{- '<|tool_call>call:' + function['name'] + '{' -}} + {%- if function['arguments'] is mapping -%} + {%- set ns_args = namespace(found_first=false) -%} + {%- for key, value in function['arguments'] | dictsort -%} + {%- if ns_args.found_first %},{% endif -%} + {%- set ns_args.found_first = true -%} + {{- key -}}:{{- format_argument(value, escape_keys=False) -}} + {%- endfor -%} + {%- elif function['arguments'] is string -%} + {{- function['arguments'] -}} + {%- endif -%} + {{- '}' -}} + {%- endfor -%} + {%- set ns.prev_message_type = 'tool_call' -%} + {%- endif -%} + + {%- set ns_tr_out = namespace(flag=false) -%} + {%- if message.get('tool_responses') -%} + {#- Legacy: tool_responses embedded on the assistant message -#} + {%- for tool_response in message['tool_responses'] -%} + {{- format_tool_response_block(tool_response['name'] | default('unknown'), tool_response['response']) -}} + {%- set ns_tr_out.flag = true -%} + {%- set ns.prev_message_type = 'tool_response' -%} + {%- endfor -%} + {%- elif message.get('tool_calls') -%} + {#- OpenAI Chat Completions: consecutive following messages with role "tool" (no break/continue; range scan) -#} + {%- set ns_tool_scan = namespace(stopped=false) -%} + {%- for k in range(loop.index0 + 1, loop_messages | length) -%} + {%- if ns_tool_scan.stopped -%} + {%- elif loop_messages[k]['role'] != 'tool' -%} + {%- set ns_tool_scan.stopped = true -%} + {%- else -%} + {%- set follow = loop_messages[k] -%} + {%- set ns_tname = namespace(name=follow.get('name') | default('unknown')) -%} + {%- for tc in message['tool_calls'] -%} + {%- if tc.get('id') == follow.get('tool_call_id') -%} + {%- set ns_tname.name = tc['function']['name'] -%} + {%- endif -%} + {%- endfor -%} + {%- set tool_body = follow.get('content') -%} + {%- if tool_body is string -%} + {{- format_tool_response_block(ns_tname.name, tool_body) -}} + {%- elif tool_body is sequence and tool_body is not string -%} + {%- set ns_txt = namespace(s='') -%} + {%- for part in tool_body -%} + {%- if part.get('type') == 'text' -%} + {%- set ns_txt.s = ns_txt.s + (part.get('text') | default('')) -%} + {%- endif -%} + {%- endfor -%} + {{- format_tool_response_block(ns_tname.name, ns_txt.s) -}} + {%- else -%} + {{- format_tool_response_block(ns_tname.name, tool_body) -}} + {%- endif -%} + {%- set ns_tr_out.flag = true -%} + {%- set ns.prev_message_type = 'tool_response' -%} + {%- endif -%} + {%- endfor -%} + {%- endif -%} + + {%- if message['content'] is string -%} + {%- if role == 'model' -%} + {{- strip_thinking(message['content']) -}} + {%- else -%} + {{- message['content'] | trim -}} + {%- endif -%} + {%- elif message['content'] is sequence -%} + {%- for item in message['content'] -%} + {%- if item['type'] == 'text' -%} + {%- if role == 'model' -%} + {{- strip_thinking(item['text']) -}} + {%- else -%} + {{- item['text'] | trim -}} + {%- endif -%} + {%- elif item['type'] == 'image' -%} + {{- '\n\n<|image|>\n\n' -}} + {%- set ns.prev_message_type = 'image' -%} + {%- elif item['type'] == 'audio' -%} + {{- '<|audio|>' -}} + {%- set ns.prev_message_type = 'audio' -%} + {%- elif item['type'] == 'video' -%} + {{- '\n\n<|video|>\n\n' -}} + {%- set ns.prev_message_type = 'video' -%} + {%- endif -%} + {%- endfor -%} + {%- endif -%} + + {%- if not (ns_tr_out.flag and not message.get('content')) -%} + {{- '\n' -}} + {%- endif -%} + {%- endif -%} +{%- endfor -%} + +{%- if add_generation_prompt -%} + {%- if ns.prev_message_type != 'tool_response' -%} + {{- '<|turn>model\n' -}} + {%- endif -%} + {%- if not enable_thinking | default(false) -%} + {{- '<|channel>thought\n' -}} + {%- endif -%} +{%- endif -%} diff --git a/tests/reasoning/test_gemma4_reasoning_parser.py b/tests/reasoning/test_gemma4_reasoning_parser.py index cdda7dea51d4..699fc509d828 100644 --- a/tests/reasoning/test_gemma4_reasoning_parser.py +++ b/tests/reasoning/test_gemma4_reasoning_parser.py @@ -4,6 +4,9 @@ import pytest from tests.reasoning.utils import run_reasoning_extraction +from vllm.entrypoints.openai.chat_completion.protocol import ( + ChatCompletionRequest, +) from vllm.reasoning import ReasoningParser, ReasoningParserManager # Using mistral tokenizer as a generic mock since the actual model is not on HF @@ -100,6 +103,39 @@ def generic_tokenizer(): "is_reasoning_end": True, } +THOUGHT_PREFIX = { + "output": "<|channel>thought\nActual reasoning hereFinal answer", + "reasoning": "Actual reasoning here", + "content": "Final answer", + "is_reasoning_end": True, +} +THOUGHT_PREFIX_ONLY = { + "output": "<|channel>thought\n", + "reasoning": "", + "content": None, + "is_reasoning_end": True, +} +THOUGHT_PREFIX_MULTILINE = { + "output": "<|channel>thought\nLine1\nLine2Answer", + "reasoning": "Line1\nLine2", + "content": "Answer", + "is_reasoning_end": True, +} +# "thousand" starts like "thought" but diverges — exercises Case 2→3 in streaming. +THOUGHT_PREFIX_DIVERGE = { + "output": "<|channel>thousand reasonsDone", + "reasoning": "thousand reasons", + "content": "Done", + "is_reasoning_end": True, +} +# The model isn't reasoning if we're generating tool calls. +TOOL_CALL_STARTED = { + "output": "<|tool_call>", + "reasoning": None, + "content": "<|tool_call>", + "is_reasoning_end": True, +} + TEST_CASES = [ pytest.param(False, INVALID_SIMPLE_NONSTREAMING, id="invalid_simple"), pytest.param(True, INVALID_SIMPLE_STREAMING, id="invalid_simple_streaming"), @@ -120,17 +156,22 @@ def generic_tokenizer(): pytest.param(False, EMPTY, id="empty"), pytest.param(False, NEW_LINE_NONSTREAMING, id="new_line"), pytest.param(True, NEW_LINE_STREAMING, id="new_line_streaming"), + pytest.param(False, THOUGHT_PREFIX, id="thought_prefix"), + pytest.param(True, THOUGHT_PREFIX, id="thought_prefix_streaming"), + pytest.param(False, THOUGHT_PREFIX_ONLY, id="thought_prefix_only"), + pytest.param(True, THOUGHT_PREFIX_ONLY, id="thought_prefix_only_streaming"), + pytest.param(False, THOUGHT_PREFIX_MULTILINE, id="thought_prefix_multiline"), + pytest.param( + True, THOUGHT_PREFIX_MULTILINE, id="thought_prefix_multiline_streaming" + ), + pytest.param(False, THOUGHT_PREFIX_DIVERGE, id="thought_prefix_diverge"), + pytest.param(True, THOUGHT_PREFIX_DIVERGE, id="thought_prefix_diverge_streaming"), + pytest.param(False, TOOL_CALL_STARTED, id="tool_call_started"), + pytest.param(True, TOOL_CALL_STARTED, id="tool_call_started_streaming"), ] -@pytest.mark.parametrize("streaming, param_dict", TEST_CASES) -def test_gemma4_reasoning( - streaming: bool, - param_dict: dict, - generic_tokenizer, -): - output = param_dict["output"] - +def gemma4_encode_output(generic_tokenizer, output: str) -> list[int]: # Resolve token IDs dynamically from the real tokenizer vocab = generic_tokenizer.get_vocab() start_token_id = vocab["<|channel>"] @@ -176,6 +217,18 @@ def _encode(text: str) -> list[int]: else: output_tokens += _encode(output) + return output_tokens + + +@pytest.mark.parametrize("streaming, param_dict", TEST_CASES) +def test_gemma4_reasoning( + streaming: bool, + param_dict: dict, + generic_tokenizer, +): + output = param_dict["output"] + output_tokens = gemma4_encode_output(generic_tokenizer, output) + parser: ReasoningParser = ReasoningParserManager.get_reasoning_parser(parser_name)( generic_tokenizer ) @@ -194,3 +247,29 @@ def _encode(text: str) -> list[int]: # Test is_reasoning_end is_reasoning_end = parser.is_reasoning_end(output_tokens) assert is_reasoning_end == param_dict["is_reasoning_end"] + + +def test_gemma4_adjust_request(generic_tokenizer): + parser: ReasoningParser = ReasoningParserManager.get_reasoning_parser(parser_name)( + generic_tokenizer + ) + + request = ChatCompletionRequest(messages=[], model="test-model") + assert request.skip_special_tokens is True + + result = parser.adjust_request(request) + assert result.skip_special_tokens is False + assert result is request + + +def test_gemma4_previous_turn_reasoning_is_reasoning_end(generic_tokenizer): + output = ( + "<|channel>thought\n1st thought1st content\n" + "<|turn>user\nThanks<|turn>model\n" + ) + output_tokens = gemma4_encode_output(generic_tokenizer, output) + parser: ReasoningParser = ReasoningParserManager.get_reasoning_parser(parser_name)( + generic_tokenizer + ) + is_reasoning_end = parser.is_reasoning_end(output_tokens) + assert not is_reasoning_end diff --git a/tests/renderers/test_gemma4_chat_template.py b/tests/renderers/test_gemma4_chat_template.py new file mode 100644 index 000000000000..a4a0b41d053a --- /dev/null +++ b/tests/renderers/test_gemma4_chat_template.py @@ -0,0 +1,345 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +"""Tests for Gemma4 chat template rendering.""" + +from pathlib import Path + +import jinja2.sandbox +import pytest + +TEMPLATE_PATH = ( + Path(__file__).resolve().parent.parent.parent + / "examples" + / "tool_chat_template_gemma4.jinja" +) + + +@pytest.fixture(scope="module") +def gemma4_template(): + """Load and compile the Gemma4 chat template.""" + template_str = TEMPLATE_PATH.read_text() + env = jinja2.sandbox.ImmutableSandboxedEnvironment() + return env.from_string(template_str) + + +def _render(template, messages, **kwargs): + """Render the template with sensible defaults.""" + kwargs.setdefault("bos_token", "") + kwargs.setdefault("add_generation_prompt", False) + return template.render(messages=messages, **kwargs) + + +class TestGemma4ChatTemplate: + def test_basic_multiturn_thinking_disabled(self, gemma4_template): + """With enable_thinking=False (default), generation prompt ends with + an empty thought channel to suppress thinking.""" + messages = [ + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi there!"}, + {"role": "user", "content": "How are you?"}, + ] + result = _render(gemma4_template, messages, add_generation_prompt=True) + assert "<|turn>user\n" in result + assert "<|turn>model\n" in result + assert "Hello" in result + assert "Hi there!" in result + assert "How are you?" in result + assert result.rstrip("\n").endswith("<|channel>thought\n") + + def test_basic_multiturn_thinking_enabled(self, gemma4_template): + """With enable_thinking=True, generation prompt ends with model + turn opener (no thought suppression).""" + messages = [ + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi there!"}, + {"role": "user", "content": "How are you?"}, + ] + result = _render( + gemma4_template, + messages, + add_generation_prompt=True, + enable_thinking=True, + ) + assert "<|turn>user\n" in result + assert "<|turn>model\n" in result + assert "Hello" in result + assert "Hi there!" in result + assert "How are you?" in result + assert result.rstrip("\n").endswith("<|turn>model") + + def test_system_message(self, gemma4_template): + messages = [ + {"role": "system", "content": "You are helpful."}, + {"role": "user", "content": "Hi"}, + ] + result = _render(gemma4_template, messages) + assert "<|turn>system\n" in result + assert "You are helpful." in result + + def test_thinking_enabled(self, gemma4_template): + messages = [{"role": "user", "content": "Think about this"}] + result = _render( + gemma4_template, + messages, + add_generation_prompt=True, + enable_thinking=True, + ) + assert "<|think|>" in result + assert "<|turn>system\n" in result + + def test_tool_declarations(self, gemma4_template): + tools = [ + { + "type": "function", + "function": { + "name": "get_weather", + "description": "Get weather for a city", + "parameters": { + "type": "object", + "properties": { + "city": { + "type": "string", + "description": "City name", + } + }, + "required": ["city"], + }, + }, + } + ] + messages = [{"role": "user", "content": "What is the weather?"}] + result = _render( + gemma4_template, + messages, + tools=tools, + add_generation_prompt=True, + ) + assert "<|tool>" in result + assert "declaration:get_weather" in result + assert "" in result + assert '<|"|>City name<|"|>' in result + + def test_tool_calls_in_assistant(self, gemma4_template): + messages = [ + {"role": "user", "content": "Weather in London?"}, + { + "role": "assistant", + "content": "", + "tool_calls": [ + { + "id": "call_1", + "function": { + "name": "get_weather", + "arguments": {"city": "London"}, + }, + } + ], + }, + ] + result = _render(gemma4_template, messages) + assert "<|tool_call>call:get_weather{" in result + assert "}" in result + assert '<|"|>London<|"|>' in result + + def test_tool_responses_openai_style(self, gemma4_template): + """role='tool' messages are formatted as <|tool_response> blocks + with content dumped as-is.""" + messages = [ + {"role": "user", "content": "Weather?"}, + { + "role": "assistant", + "content": "", + "tool_calls": [ + { + "id": "call_1", + "function": { + "name": "get_weather", + "arguments": {"city": "London"}, + }, + } + ], + }, + { + "role": "tool", + "tool_call_id": "call_1", + "content": '{"temperature": 15, "condition": "sunny"}', + }, + ] + result = _render(gemma4_template, messages, add_generation_prompt=True) + assert "<|tool_response>" in result + assert "response:get_weather{" in result + assert "" in result + assert '"temperature": 15' in result + + def test_tool_responses_legacy_style(self, gemma4_template): + """tool_responses embedded on the assistant message.""" + messages = [ + {"role": "user", "content": "Weather?"}, + { + "role": "assistant", + "content": "", + "tool_calls": [ + { + "function": { + "name": "get_weather", + "arguments": {"city": "London"}, + }, + } + ], + "tool_responses": [ + { + "name": "get_weather", + "response": {"temperature": 20}, + } + ], + }, + ] + result = _render(gemma4_template, messages) + assert "<|tool_response>" in result + assert "response:get_weather{" in result + assert "temperature:" in result + + def test_generation_prompt_not_after_tool_response(self, gemma4_template): + """add_generation_prompt=True should NOT add <|turn>model when the + last message type was tool_response (the model turn continues).""" + messages = [ + {"role": "user", "content": "Weather?"}, + { + "role": "assistant", + "content": "", + "tool_calls": [ + { + "id": "call_1", + "function": { + "name": "get_weather", + "arguments": {"city": "London"}, + }, + } + ], + }, + { + "role": "tool", + "tool_call_id": "call_1", + "content": "sunny", + }, + ] + result = _render(gemma4_template, messages, add_generation_prompt=True) + assert not result.strip().endswith("<|turn>model\n") + + def test_reasoning_in_tool_chains(self, gemma4_template): + """reasoning field on assistant with tool_calls after last user + message emits <|channel>thought\\n....""" + messages = [ + {"role": "user", "content": "Calculate something"}, + { + "role": "assistant", + "content": "", + "reasoning": "Let me think about this...", + "tool_calls": [ + { + "function": { + "name": "calculator", + "arguments": {"expr": "2+2"}, + }, + } + ], + }, + ] + result = _render(gemma4_template, messages) + assert "<|channel>thought\n" in result + assert "Let me think about this..." in result + assert "" in result + + def test_reasoning_not_before_last_user(self, gemma4_template): + """reasoning on assistant BEFORE the last user message is dropped.""" + messages = [ + {"role": "user", "content": "First"}, + { + "role": "assistant", + "content": "Response", + "reasoning": "Old reasoning that should be dropped", + "tool_calls": [ + { + "function": { + "name": "fn", + "arguments": {}, + }, + } + ], + }, + {"role": "user", "content": "Second"}, + ] + result = _render(gemma4_template, messages, add_generation_prompt=True) + assert "Old reasoning" not in result + + def test_strip_thinking_in_model_content(self, gemma4_template): + """<|channel>... in model content is stripped by the + strip_thinking macro.""" + messages = [ + {"role": "user", "content": "Hi"}, + { + "role": "assistant", + "content": ("<|channel>internal thoughtVisible answer"), + }, + ] + result = _render(gemma4_template, messages) + assert "internal thought" not in result + assert "Visible answer" in result + + def test_multi_turn_tool_chain(self, gemma4_template): + """assistant->tool->assistant->tool produces exactly one + <|turn>model (later assistants continue the same turn).""" + messages = [ + {"role": "user", "content": "Do two things"}, + { + "role": "assistant", + "content": "", + "tool_calls": [ + { + "id": "c1", + "function": {"name": "step1", "arguments": {}}, + }, + ], + }, + {"role": "tool", "tool_call_id": "c1", "content": "result1"}, + { + "role": "assistant", + "content": "", + "tool_calls": [ + { + "id": "c2", + "function": {"name": "step2", "arguments": {}}, + }, + ], + }, + {"role": "tool", "tool_call_id": "c2", "content": "result2"}, + ] + result = _render(gemma4_template, messages, add_generation_prompt=True) + assert result.count("<|turn>model\n") == 1 + + def test_format_argument_types(self, gemma4_template): + """Strings wrapped in <|"|>, booleans as true/false, numbers bare.""" + messages = [ + {"role": "user", "content": "Test"}, + { + "role": "assistant", + "content": "", + "tool_calls": [ + { + "function": { + "name": "test_fn", + "arguments": { + "name": "Alice", + "active": True, + "count": 42, + }, + }, + } + ], + }, + ] + result = _render(gemma4_template, messages) + assert '<|"|>Alice<|"|>' in result + assert "active:true" in result + assert "count:42" in result diff --git a/tests/tool_parsers/test_gemma4_tool_parser.py b/tests/tool_parsers/test_gemma4_tool_parser.py index 2dcfdc65c39b..54f1e7eed8ff 100644 --- a/tests/tool_parsers/test_gemma4_tool_parser.py +++ b/tests/tool_parsers/test_gemma4_tool_parser.py @@ -114,6 +114,19 @@ def test_empty_value(self): result = _parse_gemma4_args("key:") assert result == {"key": ""} + def test_empty_value_partial_withheld(self): + """Key with no value is withheld in partial mode to avoid premature emission.""" + result = _parse_gemma4_args("key:", partial=True) + assert result == {} + # also with a space after the colon + result = _parse_gemma4_args("key: ", partial=True) + assert result == {} + + def test_empty_value_after_other_keys_partial_withheld(self): + """Trailing key with no value is withheld; earlier keys are kept.""" + result = _parse_gemma4_args('name:<|"|>test<|"|>,flag:', partial=True) + assert result == {"name": "test"} + class TestParseGemma4Array: def test_string_array(self): @@ -636,3 +649,30 @@ def test_streaming_html_argument_does_not_duplicate_tag_prefixes( ' \n' ' \n' ) + + def test_streaming_trailing_bare_bool_not_duplicated(self, parser, mock_request): + """Trailing bare boolean must not be streamed twice.""" + chunks = [ + "<|tool_call>", + "call:Edit{", + 'file_path:<|"|>src/env.py<|"|>,', + 'old_string:<|"|>old_val<|"|>,', + 'new_string:<|"|>new_val<|"|>,', + "replace_all:", + "false}", + "", + ] + + results = self._simulate_streaming(parser, mock_request, chunks) + args_text = self._collect_arguments(results) + assert args_text, "No arguments were streamed" + + parsed_args = json.loads(args_text) + assert parsed_args == { + "file_path": "src/env.py", + "old_string": "old_val", + "new_string": "new_val", + "replace_all": False, + } + + assert args_text.count("replace_all") == 1 diff --git a/vllm/entrypoints/anthropic/serving.py b/vllm/entrypoints/anthropic/serving.py index 9270a49d1d1b..5635cfbd3596 100644 --- a/vllm/entrypoints/anthropic/serving.py +++ b/vllm/entrypoints/anthropic/serving.py @@ -170,7 +170,8 @@ def _convert_messages( else: cls._convert_message_content(msg, openai_msg, openai_messages) - openai_messages.append(openai_msg) + if not (msg.role == "user" and "content" not in openai_msg): + openai_messages.append(openai_msg) @classmethod def _convert_message_content( diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index bf23c7e2e96b..2b6cb810ea31 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -379,6 +379,7 @@ async def init_app_state( enable_auto_tools=args.enable_auto_tool_choice, exclude_tools_when_tool_choice_none=args.exclude_tools_when_tool_choice_none, tool_parser=args.tool_call_parser, + reasoning_parser=args.structured_outputs_config.reasoning_parser, default_chat_template_kwargs=args.default_chat_template_kwargs, log_error_stack=args.log_error_stack, ) @@ -481,6 +482,7 @@ async def init_render_app_state( enable_auto_tools=args.enable_auto_tool_choice, exclude_tools_when_tool_choice_none=args.exclude_tools_when_tool_choice_none, tool_parser=args.tool_call_parser, + reasoning_parser=args.structured_outputs_config.reasoning_parser, default_chat_template_kwargs=args.default_chat_template_kwargs, log_error_stack=args.log_error_stack, ) diff --git a/vllm/entrypoints/openai/responses/serving.py b/vllm/entrypoints/openai/responses/serving.py index 086101c28d8e..7edeb8ab0cbe 100644 --- a/vllm/entrypoints/openai/responses/serving.py +++ b/vllm/entrypoints/openai/responses/serving.py @@ -594,6 +594,7 @@ async def _make_request( default_template_kwargs=None, tool_dicts=tool_dicts, tool_parser=self.parser.tool_parser_cls if self.parser else None, + reasoning_parser=self.parser.reasoning_parser_cls if self.parser else None, ) return messages, engine_inputs @@ -618,6 +619,7 @@ async def _render_next_turn( default_template_kwargs=None, tool_dicts=tool_dicts, tool_parser=tool_parser, + reasoning_parser=self.parser.reasoning_parser_cls if self.parser else None, ) return engine_inputs diff --git a/vllm/entrypoints/serve/render/serving.py b/vllm/entrypoints/serve/render/serving.py index 43ea5127b3be..2aaa83e7564b 100644 --- a/vllm/entrypoints/serve/render/serving.py +++ b/vllm/entrypoints/serve/render/serving.py @@ -44,6 +44,7 @@ ) from vllm.logger import init_logger from vllm.parser import ParserManager +from vllm.reasoning.abs_reasoning_parsers import ReasoningParser from vllm.renderers import BaseRenderer, merge_kwargs from vllm.renderers.inputs.preprocess import ( extract_prompt_components, @@ -74,6 +75,7 @@ def __init__( enable_auto_tools: bool = False, exclude_tools_when_tool_choice_none: bool = False, tool_parser: str | None = None, + reasoning_parser: str | None = None, default_chat_template_kwargs: dict[str, Any] | None = None, log_error_stack: bool = False, ) -> None: @@ -94,6 +96,11 @@ def __init__( enable_auto_tools=enable_auto_tools, model_name=model_config.model, ) + self.reasoning_parser: type[ReasoningParser] | None = ( + ParserManager.get_reasoning_parser( + reasoning_parser_name=reasoning_parser, + ) + ) self.default_chat_template_kwargs: dict[str, Any] = ( default_chat_template_kwargs or {} ) @@ -245,6 +252,7 @@ async def render_chat( default_template_kwargs=self.default_chat_template_kwargs, tool_dicts=tool_dicts, tool_parser=tool_parser, + reasoning_parser=self.reasoning_parser, ) else: # For GPT-OSS. @@ -503,6 +511,7 @@ async def preprocess_chat( default_template_kwargs: dict[str, Any] | None, tool_dicts: list[dict[str, Any]] | None = None, tool_parser: type[ToolParser] | None = None, + reasoning_parser: type[ReasoningParser] | None = None, *, skip_mm_cache: bool = False, ) -> tuple[list[ConversationMessage], list[EngineInput]]: @@ -539,6 +548,10 @@ async def preprocess_chat( skip_mm_cache=skip_mm_cache, ) + if reasoning_parser is not None: + tokenizer = renderer.get_tokenizer() + request = reasoning_parser(tokenizer).adjust_request(request=request) + # tool parsing is done only if a tool_parser has been set and if # tool_choice is not "none" (if tool_choice is "none" but a tool_parser # is set, we want to prevent parsing a tool_call hallucinated by the LLM diff --git a/vllm/parser/abstract_parser.py b/vllm/parser/abstract_parser.py index 30b4c4ebee7e..d777f80f40b1 100644 --- a/vllm/parser/abstract_parser.py +++ b/vllm/parser/abstract_parser.py @@ -496,6 +496,15 @@ def _parse_tool_calls( # No tool calls return [], content + def adjust_request( + self, request: ChatCompletionRequest | ResponsesRequest + ) -> ChatCompletionRequest | ResponsesRequest: + if self._reasoning_parser is not None: + request = self._reasoning_parser.adjust_request(request) + if self._tool_parser is not None: + request = self._tool_parser.adjust_request(request) + return request + def extract_reasoning_streaming( self, previous_text: str, diff --git a/vllm/reasoning/abs_reasoning_parsers.py b/vllm/reasoning/abs_reasoning_parsers.py index 5271a307075e..7244b9fac84f 100644 --- a/vllm/reasoning/abs_reasoning_parsers.py +++ b/vllm/reasoning/abs_reasoning_parsers.py @@ -6,7 +6,7 @@ from abc import abstractmethod from collections.abc import Callable, Iterable, Sequence from functools import cached_property -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, cast from vllm.entrypoints.mcp.tool_server import ToolServer from vllm.logger import init_logger @@ -150,6 +150,12 @@ def extract_reasoning_streaming( previously been parsed and extracted (see constructor) """ + def adjust_request( + self, request: "ChatCompletionRequest | ResponsesRequest" + ) -> "ChatCompletionRequest | ResponsesRequest": + """Adjust request parameters; override in subclasses as needed.""" + return request + def prepare_structured_tag( self, original_tag: str | None, @@ -298,7 +304,7 @@ def _decorator(obj: type[ReasoningParser]) -> type[ReasoningParser]: if isinstance(name, str): names = [name] elif is_list_of(name, str): - names = name + names = cast(list[str], name) else: names = [class_name] diff --git a/vllm/reasoning/gemma4_reasoning_parser.py b/vllm/reasoning/gemma4_reasoning_parser.py index efcdcca237b7..6f2241603f9a 100644 --- a/vllm/reasoning/gemma4_reasoning_parser.py +++ b/vllm/reasoning/gemma4_reasoning_parser.py @@ -52,6 +52,16 @@ def __init__(self, tokenizer: TokenizerLike, *args, **kwargs): # skip_special_tokens=True). self._reasoning_text: str = "" self._prefix_stripped: bool = False + self.new_turn_token_id = self.vocab["<|turn>"] + self.tool_call_token_id = self.vocab["<|tool_call>"] + self.tool_response_token_id = self.vocab["<|tool_response>"] + + def adjust_request( + self, request: "ChatCompletionRequest | ResponsesRequest" + ) -> "ChatCompletionRequest | ResponsesRequest": + """Disable special-token stripping to preserve boundary tokens.""" + request.skip_special_tokens = False + return request @property def start_token(self) -> str: @@ -63,6 +73,29 @@ def end_token(self) -> str: """The token that ends reasoning content.""" return "" + def is_reasoning_end(self, input_ids: Sequence[int]) -> bool: + start_token_id = self.start_token_id + end_token_id = self.end_token_id + new_turn_token_id = self.new_turn_token_id + tool_call_token_id = self.tool_call_token_id + tool_response_token_id = self.tool_response_token_id + + # Search from the end of input_ids to find the last match. + for i in range(len(input_ids) - 1, -1, -1): + if input_ids[i] == start_token_id: + return False + if input_ids[i] == tool_call_token_id: + # We're generating a tool call, so reasoning must be ended. + return True + if input_ids[i] in (new_turn_token_id, tool_response_token_id): + # We found a new turn or tool response token so don't consider + # reasoning ended yet, since the model starts new reasoning + # after these tokens. + return False + if input_ids[i] == end_token_id: + return True + return False + # ------------------------------------------------------------------ # Non-streaming path # ------------------------------------------------------------------ @@ -159,11 +192,10 @@ def extract_reasoning_streaming( result.reasoning = stripped return result else: - # This entire delta was prefix — suppress it. - # Don't set _prefix_stripped yet; there may be more - # prefix chars to consume in the next delta. if len(self._reasoning_text) >= prefix_len: self._prefix_stripped = True + result.reasoning = "" + return result return None # Case 2: Accumulated text is a strict prefix of diff --git a/vllm/tool_parsers/gemma4_tool_parser.py b/vllm/tool_parsers/gemma4_tool_parser.py index b3f71ab2fb51..b96fb919c80c 100644 --- a/vllm/tool_parsers/gemma4_tool_parser.py +++ b/vllm/tool_parsers/gemma4_tool_parser.py @@ -122,14 +122,16 @@ def _parse_gemma4_args(args_str: str, *, partial: bool = False) -> dict: # Parse value if i >= n: - result[key] = "" + if not partial: + result[key] = "" break # Skip whitespace after ':' while i < n and args_str[i] in (" ", "\n", "\t"): i += 1 if i >= n: - result[key] = "" + if not partial: + result[key] = "" break # String value: <|"|>...<|"|>