Skip to content

Commit fd51456

Browse files
committed
[Bugfix] [Frontend] Fix Gemma 4 reasoning parser for multi-turn and tool calls
Add is_reasoning_end override so the parser correctly detects reasoning boundaries when tool calls are present or across multi-turn conversations. Also refactor test helpers for reuse. Signed-off-by: Ben Browning <bbrownin@redhat.com>
1 parent b060a04 commit fd51456

File tree

2 files changed

+61
-8
lines changed

2 files changed

+61
-8
lines changed

tests/reasoning/test_gemma4_reasoning_parser.py

Lines changed: 35 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,13 @@ def generic_tokenizer():
128128
"content": "Done",
129129
"is_reasoning_end": True,
130130
}
131+
# Once the model emits a tool-call marker it is no longer reasoning, so the
# marker itself is surfaced as content and reasoning is considered finished.
TOOL_CALL_STARTED = dict(
    output="<|tool_call>",
    reasoning=None,
    content="<|tool_call>",
    is_reasoning_end=True,
)
131138

132139
TEST_CASES = [
133140
pytest.param(False, INVALID_SIMPLE_NONSTREAMING, id="invalid_simple"),
@@ -159,17 +166,12 @@ def generic_tokenizer():
159166
),
160167
pytest.param(False, THOUGHT_PREFIX_DIVERGE, id="thought_prefix_diverge"),
161168
pytest.param(True, THOUGHT_PREFIX_DIVERGE, id="thought_prefix_diverge_streaming"),
169+
pytest.param(False, TOOL_CALL_STARTED, id="tool_call_started"),
170+
pytest.param(True, TOOL_CALL_STARTED, id="tool_call_started_streaming"),
162171
]
163172

164173

165-
@pytest.mark.parametrize("streaming, param_dict", TEST_CASES)
166-
def test_gemma4_reasoning(
167-
streaming: bool,
168-
param_dict: dict,
169-
generic_tokenizer,
170-
):
171-
output = param_dict["output"]
172-
174+
def gemma4_encode_output(generic_tokenizer, output: str) -> list[int]:
173175
# Resolve token IDs dynamically from the real tokenizer
174176
vocab = generic_tokenizer.get_vocab()
175177
start_token_id = vocab["<|channel>"]
@@ -215,6 +217,18 @@ def _encode(text: str) -> list[int]:
215217
else:
216218
output_tokens += _encode(output)
217219

220+
return output_tokens
221+
222+
223+
@pytest.mark.parametrize("streaming, param_dict", TEST_CASES)
224+
def test_gemma4_reasoning(
225+
streaming: bool,
226+
param_dict: dict,
227+
generic_tokenizer,
228+
):
229+
output = param_dict["output"]
230+
output_tokens = gemma4_encode_output(generic_tokenizer, output)
231+
218232
parser: ReasoningParser = ReasoningParserManager.get_reasoning_parser(parser_name)(
219233
generic_tokenizer
220234
)
@@ -246,3 +260,16 @@ def test_gemma4_adjust_request(generic_tokenizer):
246260
result = parser.adjust_request(request)
247261
assert result.skip_special_tokens is False
248262
assert result is request
263+
264+
265+
def test_gemma4_previous_turn_reasoning_is_reasoning_end(generic_tokenizer):
    """Reasoning that ended in a previous turn must not count as ended.

    After a completed reasoning block, a user turn and a fresh model turn
    follow; the parser should report that reasoning has not (yet) ended for
    the new turn.
    """
    previous_turn = "<|channel>thought\n1st thought<channel|>1st content<turn|>\n"
    next_turns = "<|turn>user\nThanks<|turn>model\n"
    token_ids = gemma4_encode_output(generic_tokenizer, previous_turn + next_turns)

    parser_cls = ReasoningParserManager.get_reasoning_parser(parser_name)
    parser: ReasoningParser = parser_cls(generic_tokenizer)

    assert not parser.is_reasoning_end(token_ids)

vllm/reasoning/gemma4_reasoning_parser.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,9 @@ def __init__(self, tokenizer: TokenizerLike, *args, **kwargs):
5252
# skip_special_tokens=True).
5353
self._reasoning_text: str = ""
5454
self._prefix_stripped: bool = False
55+
self.new_turn_token_id = self.vocab["<|turn>"]
56+
self.tool_call_token_id = self.vocab["<|tool_call>"]
57+
self.tool_response_token_id = self.vocab["<|tool_response>"]
5558

5659
def adjust_request(
5760
self, request: "ChatCompletionRequest | ResponsesRequest"
@@ -70,6 +73,29 @@ def end_token(self) -> str:
7073
"""The token that ends reasoning content."""
7174
return "<channel|>"
7275

76+
def is_reasoning_end(self, input_ids: Sequence[int]) -> bool:
    """Return True if reasoning has ended in the generated token stream.

    Walks the tokens from newest to oldest and decides based on the most
    recent boundary marker found:

    * reasoning-start token -> reasoning is still in progress (False)
    * tool-call token       -> a tool call is being generated, so any
      reasoning must already be over (True)
    * new-turn / tool-response token -> the model begins fresh reasoning
      after these, so reasoning is not considered ended (False)
    * reasoning-end token   -> reasoning finished normally (True)

    With no marker present at all, reasoning has not ended.
    """
    for token_id in reversed(input_ids):
        if token_id == self.start_token_id:
            return False
        if token_id == self.tool_call_token_id:
            return True
        if token_id in (self.new_turn_token_id, self.tool_response_token_id):
            return False
        if token_id == self.end_token_id:
            return True
    return False
98+
7399
# ------------------------------------------------------------------
74100
# Non-streaming path
75101
# ------------------------------------------------------------------

0 commit comments

Comments
 (0)