46 changes: 45 additions & 1 deletion tests/entrypoints/openai/chat_completion/test_serving_chat.py
@@ -23,7 +23,10 @@
     ChatCompletionRequest,
     ChatCompletionResponse,
 )
-from vllm.entrypoints.openai.chat_completion.serving import OpenAIServingChat
+from vllm.entrypoints.openai.chat_completion.serving import (
+    OpenAIServingChat,
+    _get_reasoning_parser_input_text,
+)
 from vllm.entrypoints.openai.engine.protocol import (
     ErrorResponse,
     RequestResponseMetadata,
@@ -38,6 +41,7 @@
 from vllm.exceptions import VLLMValidationError
 from vllm.inputs import TokensPrompt
 from vllm.outputs import CompletionOutput, RequestOutput
+from vllm.reasoning.gemma4_reasoning_parser import Gemma4ReasoningParser
 from vllm.renderers.hf import HfRenderer
 from vllm.renderers.mistral import MistralRenderer
 from vllm.tokenizers import get_tokenizer
@@ -2094,3 +2098,43 @@ def test_function_is_none(self):
         assert tc.type == "function"
         assert tc.function.name is None
         assert tc.function.arguments == '{"data": "value"}'
+
+
+class _Gemma4ReasoningTokenizer:
+    def get_vocab(self) -> dict[str, int]:
+        return {
+            "<|channel>": 100,
+            "<channel|>": 101,
+        }
+
+    def decode(self, ids, skip_special_tokens: bool = False) -> str:
+        if ids != [100, 200, 101, 300]:
+            raise AssertionError(f"unexpected ids: {ids!r}")
+        if skip_special_tokens:
+            return "thought\nThe user is asking...\n2 + 2 = 4"
+        return "<|channel>thought\nThe user is asking...\n<channel|>2 + 2 = 4"
+
+
+def test_gemma4_non_streaming_reasoning_uses_token_ids():
+    tokenizer = _Gemma4ReasoningTokenizer()
+    parser = Gemma4ReasoningParser(tokenizer)
+    output = CompletionOutput(
+        index=0,
+        text="thought\nThe user is asking...\n2 + 2 = 4",
+        token_ids=[100, 200, 101, 300],
+        cumulative_logprob=0.0,
+        logprobs=None,
+    )
+    request = ChatCompletionRequest(
+        model=MODEL_NAME,
+        messages=[{"role": "user", "content": "What is 2+2?"}],
+    )
+
+    parser_input = _get_reasoning_parser_input_text(output, tokenizer, parser)
+    reasoning, content = parser.extract_reasoning(parser_input, request=request)
+
+    assert parser_input == (
+        "<|channel>thought\nThe user is asking...\n<channel|>2 + 2 = 4"
+    )
+    assert reasoning == "The user is asking...\n"
+    assert content == "2 + 2 = 4"
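
Note: the fake tokenizer above only exercises the marker-present path. As a complementary sketch of the fallback path (the stub classes here are hypothetical, not vLLM types), when none of the parser's marker token IDs appear in token_ids, the helper should return output.text without re-decoding:

from dataclasses import dataclass, field

@dataclass
class _StubOutput:
    text: str
    token_ids: list[int] = field(default_factory=list)

class _StubParser:
    # Assumed marker IDs; the real parser derives these from the vocab.
    start_token_id = 100
    end_token_id = 101

class _StubTokenizer:
    def decode(self, ids, skip_special_tokens: bool = False) -> str:
        raise AssertionError("decode must not be called on the fallback path")

def _get_reasoning_parser_input_text(output, tokenizer, parser) -> str:
    # Restates the helper added in this PR, for the sketch only.
    start = getattr(parser, "start_token_id", None)
    end = getattr(parser, "end_token_id", None)
    if output.token_ids and (
        (start is not None and start in output.token_ids)
        or (end is not None and end in output.token_ids)
    ):
        return tokenizer.decode(output.token_ids, skip_special_tokens=False)
    return output.text

# No marker IDs in the output -> the already-detokenized text is used as-is.
out = _StubOutput(text="2 + 2 = 4", token_ids=[200, 300])
assert _get_reasoning_parser_input_text(out, _StubTokenizer(), _StubParser()) == "2 + 2 = 4"
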
23 changes: 22 additions & 1 deletion vllm/entrypoints/openai/chat_completion/serving.py
@@ -84,6 +84,22 @@
 logger = init_logger(__name__)
 
 
+def _get_reasoning_parser_input_text(
+    output: CompletionOutput,
+    tokenizer: TokenizerLike,
+    reasoning_parser: ReasoningParser,
+) -> str:
+    """Build non-streaming parser input without losing reasoning markers."""
+    start_token_id = getattr(reasoning_parser, "start_token_id", None)
+    end_token_id = getattr(reasoning_parser, "end_token_id", None)
+    if output.token_ids and (
+        (start_token_id is not None and start_token_id in output.token_ids)
+        or (end_token_id is not None and end_token_id in output.token_ids)
+    ):
+        return tokenizer.decode(output.token_ids, skip_special_tokens=False)
+    return output.text
+
+
 class OpenAIServingChat(OpenAIServing):
     def __init__(
         self,
@@ -1375,8 +1391,13 @@ async def chat_completion_full_generator(
             if reasoning_parser:
                 # If the reasoning parser is enabled,
                 # tool calls are extracted exclusively from the content.
+                parser_input_text = _get_reasoning_parser_input_text(
+                    output,
+                    tokenizer,
+                    reasoning_parser,
+                )
                 reasoning, content = reasoning_parser.extract_reasoning(
-                    output.text, request=request
+                    parser_input_text, request=request
                 )
                 if not request.include_reasoning:
                     reasoning = None
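
Why the helper decodes with skip_special_tokens=False: the detokenized output.text previously handed to extract_reasoning is produced with special tokens stripped, so channel markers such as <|channel> never reach the parser. A minimal sketch of the difference, using a hypothetical vocabulary (token strings and IDs are illustrative, not the model's real ones):

# Illustrative only: how skip_special_tokens changes what the parser sees.
SPECIAL = {100: "<|channel>", 101: "<channel|>"}
REGULAR = {200: "thought\nThe user is asking...\n", 300: "2 + 2 = 4"}

def decode(ids: list[int], skip_special_tokens: bool) -> str:
    pieces = []
    for i in ids:
        if i in SPECIAL and skip_special_tokens:
            continue  # special tokens are dropped on this path
        pieces.append(SPECIAL.get(i) or REGULAR[i])
    return "".join(pieces)

ids = [100, 200, 101, 300]
print(decode(ids, skip_special_tokens=True))   # markers lost; nothing to split on
print(decode(ids, skip_special_tokens=False))  # markers kept; parser can split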