From 999cb4d37f18fdb4ddf2fda6f658b75887b010bb Mon Sep 17 00:00:00 2001 From: Jacob Date: Thu, 2 Apr 2026 19:41:26 -0400 Subject: [PATCH] [Bugfix] Fix Gemma4 non-streaming reasoning parsing Co-authored-by: OpenAI Codex Signed-off-by: Jacob --- .../chat_completion/test_serving_chat.py | 46 ++++++++++++++++++- .../openai/chat_completion/serving.py | 23 +++++++++- 2 files changed, 67 insertions(+), 2 deletions(-) diff --git a/tests/entrypoints/openai/chat_completion/test_serving_chat.py b/tests/entrypoints/openai/chat_completion/test_serving_chat.py index cb356e0e1986..fabb8b258f1a 100644 --- a/tests/entrypoints/openai/chat_completion/test_serving_chat.py +++ b/tests/entrypoints/openai/chat_completion/test_serving_chat.py @@ -23,7 +23,10 @@ ChatCompletionRequest, ChatCompletionResponse, ) -from vllm.entrypoints.openai.chat_completion.serving import OpenAIServingChat +from vllm.entrypoints.openai.chat_completion.serving import ( + OpenAIServingChat, + _get_reasoning_parser_input_text, +) from vllm.entrypoints.openai.engine.protocol import ( ErrorResponse, RequestResponseMetadata, @@ -38,6 +41,7 @@ from vllm.exceptions import VLLMValidationError from vllm.inputs import TokensPrompt from vllm.outputs import CompletionOutput, RequestOutput +from vllm.reasoning.gemma4_reasoning_parser import Gemma4ReasoningParser from vllm.renderers.hf import HfRenderer from vllm.renderers.mistral import MistralRenderer from vllm.tokenizers import get_tokenizer @@ -2094,3 +2098,43 @@ def test_function_is_none(self): assert tc.type == "function" assert tc.function.name is None assert tc.function.arguments == '{"data": "value"}' + + +class _Gemma4ReasoningTokenizer: + def get_vocab(self) -> dict[str, int]: + return { + "<|channel>": 100, + "": 101, + } + + def decode(self, ids, skip_special_tokens: bool = False) -> str: + if ids != [100, 200, 101, 300]: + raise AssertionError(f"unexpected ids: {ids!r}") + if skip_special_tokens: + return "thought\nThe user is asking...\n2 + 2 = 4" + return "<|channel>thought\nThe user is asking...\n2 + 2 = 4" + + +def test_gemma4_non_streaming_reasoning_uses_token_ids(): + tokenizer = _Gemma4ReasoningTokenizer() + parser = Gemma4ReasoningParser(tokenizer) + output = CompletionOutput( + index=0, + text="thought\nThe user is asking...\n2 + 2 = 4", + token_ids=[100, 200, 101, 300], + cumulative_logprob=0.0, + logprobs=None, + ) + request = ChatCompletionRequest( + model=MODEL_NAME, + messages=[{"role": "user", "content": "What is 2+2?"}], + ) + + parser_input = _get_reasoning_parser_input_text(output, tokenizer, parser) + reasoning, content = parser.extract_reasoning(parser_input, request=request) + + assert parser_input == ( + "<|channel>thought\nThe user is asking...\n2 + 2 = 4" + ) + assert reasoning == "The user is asking...\n" + assert content == "2 + 2 = 4" diff --git a/vllm/entrypoints/openai/chat_completion/serving.py b/vllm/entrypoints/openai/chat_completion/serving.py index a426836afd35..049afee12862 100644 --- a/vllm/entrypoints/openai/chat_completion/serving.py +++ b/vllm/entrypoints/openai/chat_completion/serving.py @@ -84,6 +84,22 @@ logger = init_logger(__name__) +def _get_reasoning_parser_input_text( + output: CompletionOutput, + tokenizer: TokenizerLike, + reasoning_parser: ReasoningParser, +) -> str: + """Build non-streaming parser input without losing reasoning markers.""" + start_token_id = getattr(reasoning_parser, "start_token_id", None) + end_token_id = getattr(reasoning_parser, "end_token_id", None) + if output.token_ids and ( + (start_token_id is not None and start_token_id in output.token_ids) + or (end_token_id is not None and end_token_id in output.token_ids) + ): + return tokenizer.decode(output.token_ids, skip_special_tokens=False) + return output.text + + class OpenAIServingChat(OpenAIServing): def __init__( self, @@ -1375,8 +1391,13 @@ async def chat_completion_full_generator( if reasoning_parser: # If the reasoning parser is enabled, # tool calls are extracted exclusively from the content. + parser_input_text = _get_reasoning_parser_input_text( + output, + tokenizer, + reasoning_parser, + ) reasoning, content = reasoning_parser.extract_reasoning( - output.text, request=request + parser_input_text, request=request ) if not request.include_reasoning: reasoning = None