Skip to content

Commit 999cb4d

Browse files
committed
[Bugfix] Fix Gemma4 non-streaming reasoning parsing
Co-authored-by: OpenAI Codex
Signed-off-by: Jacob <jaco8123@gmail.com>
1 parent bb39382 commit 999cb4d

2 files changed

Lines changed: 67 additions & 2 deletions

File tree

tests/entrypoints/openai/chat_completion/test_serving_chat.py

Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,10 @@
2323
ChatCompletionRequest,
2424
ChatCompletionResponse,
2525
)
26-
from vllm.entrypoints.openai.chat_completion.serving import OpenAIServingChat
26+
from vllm.entrypoints.openai.chat_completion.serving import (
27+
OpenAIServingChat,
28+
_get_reasoning_parser_input_text,
29+
)
2730
from vllm.entrypoints.openai.engine.protocol import (
2831
ErrorResponse,
2932
RequestResponseMetadata,
@@ -38,6 +41,7 @@
3841
from vllm.exceptions import VLLMValidationError
3942
from vllm.inputs import TokensPrompt
4043
from vllm.outputs import CompletionOutput, RequestOutput
44+
from vllm.reasoning.gemma4_reasoning_parser import Gemma4ReasoningParser
4145
from vllm.renderers.hf import HfRenderer
4246
from vllm.renderers.mistral import MistralRenderer
4347
from vllm.tokenizers import get_tokenizer
@@ -2094,3 +2098,43 @@ def test_function_is_none(self):
20942098
assert tc.type == "function"
20952099
assert tc.function.name is None
20962100
assert tc.function.arguments == '{"data": "value"}'
2101+
2102+
2103+
class _Gemma4ReasoningTokenizer:
2104+
def get_vocab(self) -> dict[str, int]:
2105+
return {
2106+
"<|channel>": 100,
2107+
"<channel|>": 101,
2108+
}
2109+
2110+
def decode(self, ids, skip_special_tokens: bool = False) -> str:
2111+
if ids != [100, 200, 101, 300]:
2112+
raise AssertionError(f"unexpected ids: {ids!r}")
2113+
if skip_special_tokens:
2114+
return "thought\nThe user is asking...\n2 + 2 = 4"
2115+
return "<|channel>thought\nThe user is asking...\n<channel|>2 + 2 = 4"
2116+
2117+
2118+
def test_gemma4_non_streaming_reasoning_uses_token_ids():
    """Non-streaming path must decode token ids so channel markers survive.

    ``output.text`` has the special marker tokens stripped; the helper must
    fall back to decoding ``token_ids`` so the Gemma4 parser can still split
    reasoning from content.
    """
    stub_tokenizer = _Gemma4ReasoningTokenizer()
    reasoning_parser = Gemma4ReasoningParser(stub_tokenizer)

    completion = CompletionOutput(
        index=0,
        text="thought\nThe user is asking...\n2 + 2 = 4",
        token_ids=[100, 200, 101, 300],
        cumulative_logprob=0.0,
        logprobs=None,
    )
    chat_request = ChatCompletionRequest(
        model=MODEL_NAME,
        messages=[{"role": "user", "content": "What is 2+2?"}],
    )

    decoded_input = _get_reasoning_parser_input_text(
        completion, stub_tokenizer, reasoning_parser
    )
    reasoning, content = reasoning_parser.extract_reasoning(
        decoded_input, request=chat_request
    )

    # The helper must have re-decoded with markers intact...
    assert decoded_input == (
        "<|channel>thought\nThe user is asking...\n<channel|>2 + 2 = 4"
    )
    # ...so the parser can separate the reasoning channel from the answer.
    assert reasoning == "The user is asking...\n"
    assert content == "2 + 2 = 4"

vllm/entrypoints/openai/chat_completion/serving.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,22 @@
8484
logger = init_logger(__name__)
8585

8686

87+
def _get_reasoning_parser_input_text(
    output: CompletionOutput,
    tokenizer: TokenizerLike,
    reasoning_parser: ReasoningParser,
) -> str:
    """Build non-streaming parser input without losing reasoning markers.

    ``output.text`` may have had the parser's special start/end marker
    tokens stripped during detokenization. When either marker id appears in
    ``output.token_ids``, re-decode the ids with special tokens preserved so
    the reasoning parser can see the markers; otherwise ``output.text`` is
    already sufficient.
    """
    token_ids = output.token_ids
    if token_ids:
        # Parsers without marker attributes simply fall through to the text.
        marker_ids = (
            getattr(reasoning_parser, "start_token_id", None),
            getattr(reasoning_parser, "end_token_id", None),
        )
        if any(marker is not None and marker in token_ids for marker in marker_ids):
            return tokenizer.decode(token_ids, skip_special_tokens=False)
    return output.text
101+
102+
87103
class OpenAIServingChat(OpenAIServing):
88104
def __init__(
89105
self,
@@ -1375,8 +1391,13 @@ async def chat_completion_full_generator(
13751391
if reasoning_parser:
13761392
# If the reasoning parser is enabled,
13771393
# tool calls are extracted exclusively from the content.
1394+
parser_input_text = _get_reasoning_parser_input_text(
1395+
output,
1396+
tokenizer,
1397+
reasoning_parser,
1398+
)
13781399
reasoning, content = reasoning_parser.extract_reasoning(
1379-
output.text, request=request
1400+
parser_input_text, request=request
13801401
)
13811402
if not request.include_reasoning:
13821403
reasoning = None

0 commit comments

Comments
 (0)