diff --git a/docs/en/llm/api_server_reasoning.md b/docs/en/llm/api_server_reasoning.md
index 88c475c480..67b73f5789 100644
--- a/docs/en/llm/api_server_reasoning.md
+++ b/docs/en/llm/api_server_reasoning.md
@@ -1,12 +1,12 @@
 # Reasoning Outputs
 
-For models that support reasoning capabilities, such as [DeepSeek R1](https://huggingface.co/deepseek-ai/DeepSeek-R1), LMDeploy supports parsing the reasoning results in the service and separately records the reasoning content using `reasoning_content`.
+For models that support reasoning capabilities, such as [DeepSeek R1](https://huggingface.co/deepseek-ai/DeepSeek-R1), LMDeploy can parse reasoning outputs on the server side and expose them via `reasoning_content`.
 
 ## Examples
 
 ### DeepSeek R1
 
-We can start the DeepSeek R1 model's api_server service just like launching other models. The difference is that we need to specify --reasoning-parser\` parameter.
+We can start DeepSeek R1's `api_server` like other models, but we need to specify the `--reasoning-parser` argument.
 
 ```
 lmdeploy serve api_server deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B --reasoning-parser deepseek-r1
@@ -44,62 +44,49 @@ print("content:", content)
 
 ## Custom parser
 
-You only need to add a similar parser class in `lmdeploy/serve/openai/reasoning_parser/reasoning_parser.py`.
+Built-in reasoning parser names include:
 
-```python
-# import the required packages
-from typing import Sequence, Union, Tuple, Optional
+- `qwen-qwq`
+- `qwen3`
+- `intern-s1`
+- `deepseek-r1`
+- `deepseek-v3`
+- `gpt-oss`
+
+### Notes
+
+- `deepseek-v3`: starts in reasoning mode only when `enable_thinking=True`.
+  When `enable_thinking` is `None` (default), output is usually plain content without a reasoning segment.
+- `gpt-oss`: parses OpenAI Harmony channels: + - `final` -> `content` + - `analysis` -> `reasoning_content` + - `commentary` with `functions.*` recipient -> `tool_calls` + +### Add a custom parser + +Add a parser class under `lmdeploy/serve/openai/reasoning_parser/` and register it with `ReasoningParserManager`. +```python from lmdeploy.serve.openai.reasoning_parser import ( - ReasoningParser, ReasoningParserManager) -from lmdeploy.serve.openai.protocol import (ChatCompletionRequest, - DeltaMessage) + ReasoningParser, ReasoningParserManager +) -# define a reasoning parser and register it to lmdeploy -# the name list in register_module can be used -# in --reasoning-parser. @ReasoningParserManager.register_module(["example"]) class ExampleParser(ReasoningParser): - def __init__(self, tokenizer: object): - super().__init__(tokenizer) - - def extract_reasoning_content_streaming( - self, - previous_text: str, - current_text: str, - delta_text: str, - previous_token_ids: Sequence[int], - current_token_ids: Sequence[int], - delta_token_ids: Sequence[int], - ) -> Union[DeltaMessage, None]: - """ - Instance method that should be implemented for extracting reasoning - from an incomplete response; for use when handling reasoning calls and - streaming. Has to be an instance method because it requires state - - the current tokens/diffs, but also the information about what has - previously been parsed and extracted (see constructor) - """ - - def extract_reasoning_content( - self, model_output: str, request: ChatCompletionRequest - ) -> Tuple[Optional[str], Optional[str]]: - """ - Extract reasoning content from a complete model-generated string. - - Used for non-streaming responses where we have the entire model response - available before sending to the client. - - Args: - model_output (str): The model-generated string to extract reasoning content from. - request (ChatCompletionRequest): he request object that was used to generate the model_output. 
- - Returns: - reasoning_content (str | None): The reasoning content. - final_output (str | None): The content. - """ + def __init__(self, tokenizer: object, **kwargs): + super().__init__(tokenizer, **kwargs) + + def get_reasoning_open_tag(self) -> str | None: + return "" + + def get_reasoning_close_tag(self) -> str | None: + return "" + + def starts_in_reasoning_mode(self) -> bool: + return True ``` -Similarly, the command to start the service becomes: +Then start the service with: ``` lmdeploy serve api_server $model_path --reasoning-parser example diff --git a/docs/zh_cn/llm/api_server_reasoning.md b/docs/zh_cn/llm/api_server_reasoning.md index 4860cd1553..9cf54941ce 100644 --- a/docs/zh_cn/llm/api_server_reasoning.md +++ b/docs/zh_cn/llm/api_server_reasoning.md @@ -1,14 +1,12 @@ # Reasoning Outputs -对于支持推理能力的模型,比如 [DeepSeek R1](https://huggingface.co/deepseek-ai/DeepSeek-R1),LMDeploy 支持在服务中将推理的结果解析出来,并单独用 -reasoning_content 记录推理内容。 +对于支持推理能力的模型,比如 [DeepSeek R1](https://huggingface.co/deepseek-ai/DeepSeek-R1),LMDeploy 支持在服务端解析推理结果,并通过 `reasoning_content` 单独返回推理内容。 ## 使用示例 ### DeepSeek R1 -我们可以像启动其他模型的 api_server 服务一样启动 DeepSeek R1 的模型,只是不同的是,我们需要指定 `--reasoning-parser`。 -在 `--reasoning-parser` 传参里,我们需要指定具体的 parser。 +我们可以像启动其他模型一样启动 DeepSeek R1 的 `api_server`,但需要额外指定 `--reasoning-parser` 参数。 ``` lmdeploy serve api_server deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B --reasoning-parser deepseek-r1 @@ -46,62 +44,49 @@ print("content:", content) ## 自定义 parser -只需要在 `lmdeploy/serve/openai/reasoning_parser/reasoning_parser.py` 中添加一个类似的 parser 类即可。 +内置的 reasoning parser 名称包括: -```python -# import the required packages -from typing import Sequence, Union, Tuple, Optional +- `qwen-qwq` +- `qwen3` +- `intern-s1` +- `deepseek-r1` +- `deepseek-v3` +- `gpt-oss` + +### 说明 + +- `deepseek-v3`:仅当 `enable_thinking=True` 时,才会从推理模式开始解析。 + 当 `enable_thinking` 为 `None`(默认)时,通常不会出现推理段,输出为普通内容。 +- `gpt-oss`:基于 OpenAI Harmony channel 解析: + - `final` -> `content` + - `analysis` -> 
`reasoning_content` + - `commentary` 且 `recipient` 为 `functions.*` -> `tool_calls` + +### 添加自定义 parser + +在 `lmdeploy/serve/openai/reasoning_parser/` 目录下新增 parser 类,并通过 `ReasoningParserManager` 注册。 +```python from lmdeploy.serve.openai.reasoning_parser import ( - ReasoningParser, ReasoningParserManager) -from lmdeploy.serve.openai.protocol import (ChatCompletionRequest, - DeltaMessage) + ReasoningParser, ReasoningParserManager +) -# define a reasoning parser and register it to lmdeploy -# the name list in register_module can be used -# in --reasoning-parser. @ReasoningParserManager.register_module(["example"]) class ExampleParser(ReasoningParser): - def __init__(self, tokenizer: object): - super().__init__(tokenizer) - - def extract_reasoning_content_streaming( - self, - previous_text: str, - current_text: str, - delta_text: str, - previous_token_ids: Sequence[int], - current_token_ids: Sequence[int], - delta_token_ids: Sequence[int], - ) -> Union[DeltaMessage, None]: - """ - Instance method that should be implemented for extracting reasoning - from an incomplete response; for use when handling reasoning calls and - streaming. Has to be an instance method because it requires state - - the current tokens/diffs, but also the information about what has - previously been parsed and extracted (see constructor) - """ - - def extract_reasoning_content( - self, model_output: str, request: ChatCompletionRequest - ) -> Tuple[Optional[str], Optional[str]]: - """ - Extract reasoning content from a complete model-generated string. - - Used for non-streaming responses where we have the entire model response - available before sending to the client. - - Args: - model_output (str): The model-generated string to extract reasoning content from. - request (ChatCompletionRequest): he request object that was used to generate the model_output. - - Returns: - reasoning_content (str | None): The reasoning content. - final_output (str | None): The content. 
- """ + def __init__(self, tokenizer: object, **kwargs): + super().__init__(tokenizer, **kwargs) + + def get_reasoning_open_tag(self) -> str | None: + return "" + + def get_reasoning_close_tag(self) -> str | None: + return "" + + def starts_in_reasoning_mode(self) -> bool: + return True ``` -类似的,启动服务的命令就变成了: +然后通过以下命令启动服务: ``` lmdeploy serve api_server $model_path --reasoning-parser example diff --git a/lmdeploy/cli/utils.py b/lmdeploy/cli/utils.py index 70dea1a535..ab44bdd498 100644 --- a/lmdeploy/cli/utils.py +++ b/lmdeploy/cli/utils.py @@ -462,18 +462,20 @@ def chat_template(parser): @staticmethod def reasoning_parser(parser): """Add reasoning parser to parser.""" - from lmdeploy.serve.openai.reasoning_parser import ReasoningParserManager + legacy_names = ['qwen-qwq', 'intern-s1', 'deepseek-r1'] + from lmdeploy.serve.parsers.reasoning_parser import ReasoningParserManager return parser.add_argument( '--reasoning-parser', type=str, default=None, - help=f'The registered reasoning parser name from {ReasoningParserManager.module_dict.keys()}. ' + help=f'The registered reasoning parser name: {ReasoningParserManager.module_dict.keys()}. ' + f'Legacy names: {legacy_names}. ' 'Default to None.') @staticmethod def tool_call_parser(parser): """Add tool call parser to parser.""" - from lmdeploy.serve.openai.tool_parser import ToolParserManager + from lmdeploy.serve.parsers.tool_parser import ToolParserManager return parser.add_argument( '--tool-call-parser', diff --git a/lmdeploy/serve/openai/api_server.py b/lmdeploy/serve/openai/api_server.py index 5e84ee4221..e50ee0f81f 100644 --- a/lmdeploy/serve/openai/api_server.py +++ b/lmdeploy/serve/openai/api_server.py @@ -1,7 +1,8 @@ # Copyright (c) OpenMMLab. All rights reserved. 
# yapf: disable +from __future__ import annotations + import asyncio -import copy import json import os import re @@ -10,7 +11,12 @@ from contextlib import asynccontextmanager from functools import partial from http import HTTPStatus -from typing import Literal +from typing import TYPE_CHECKING, Literal + +if TYPE_CHECKING: + from transformers import PreTrainedTokenizerBase + + from lmdeploy.serve.parsers import ResponseParser import uvicorn from fastapi import APIRouter, Depends, FastAPI, HTTPException, Request, status @@ -40,7 +46,6 @@ MigrationRequest, ) from lmdeploy.serve.core import AsyncEngine -from lmdeploy.serve.openai.harmony_utils import GptOssChatParser from lmdeploy.serve.openai.protocol import ( AbortRequest, ChatCompletionRequest, @@ -74,10 +79,7 @@ UpdateParamsRequest, UsageInfo, ) -from lmdeploy.serve.openai.reasoning_parser.reasoning_parser import ReasoningParser, ReasoningParserManager -from lmdeploy.serve.openai.tool_parser.tool_parser import ToolParser, ToolParserManager from lmdeploy.serve.utils.server_utils import validate_json_request -from lmdeploy.tokenizer import DetokenizeState, Tokenizer from lmdeploy.utils import get_logger # yapf: enable @@ -92,12 +94,9 @@ class VariableInterface: # following are for registering to proxy server proxy_url: str | None = None api_server_url: str | None = None - # following are for reasoning parsers - reasoning_parser: ReasoningParser | None = None - # following is for tool parsers - tool_parser: ToolParser | None = None allow_terminate_by_client: bool = False enable_abort_handling: bool = False + response_parser_cls: type[ResponseParser] | None = None @staticmethod def get_session(session_id: int) -> int: @@ -180,72 +179,13 @@ def always_success(req, server_context): return None -def _create_completion_logprobs(tokenizer: Tokenizer, - token_ids: list[int] | None = None, - logprobs: list[dict[int, float]] | None = None, - skip_special_tokens: bool = True, - offset: int = 0, - all_token_ids: list[int] | 
None = None, - state: DetokenizeState = None, - spaces_between_special_tokens: bool = True): - """Create openai LogProbs for completion. - - Args: - tokenizer (Tokenizer): tokenizer. - token_ids (list[int]): output token ids. - logprobs (list[dict[int, float]]): the top logprobs for each output - position. - skip_special_tokens (bool): Whether or not to remove special tokens - in the decoding. Default to be True. - offset (int): text offset. - all_token_ids (int): the history output token ids. - state (DetokenizeState): tokenizer decode state. - spaces_between_special_tokens (bool): Whether or not to add spaces - around special tokens. The behavior of Fast tokenizers is to have - this to False. This is setup to True in slow tokenizers. - """ - if logprobs is None or len(logprobs) == 0: - return None, None, None, None - - if all_token_ids is None: - all_token_ids = [] - if state is None: - state = DetokenizeState() - - out_logprobs = LogProbs() - out_logprobs.top_logprobs = [] - for token_id, tops in zip(token_ids, logprobs): - out_logprobs.text_offset.append(offset) - out_logprobs.token_logprobs.append(tops[token_id]) - - res = {} - out_state = None - for top_id, prob in tops.items(): - response, _state = tokenizer.detokenize_incrementally( - all_token_ids + [top_id], - copy.deepcopy(state), - skip_special_tokens=skip_special_tokens, - spaces_between_special_tokens=spaces_between_special_tokens) - res[response] = prob - if top_id == token_id: - out_state = _state - offset += len(response) - out_logprobs.tokens.append(response) - - out_logprobs.top_logprobs.append(res) - state = out_state - all_token_ids.append(token_id) - - return out_logprobs, offset, all_token_ids, state - - -def _create_chat_completion_logprobs(tokenizer: Tokenizer, +def _create_chat_completion_logprobs(tokenizer: PreTrainedTokenizerBase, token_ids: list[int] | None = None, logprobs: list[dict[int, float]] | None = None): """Create openai LogProbs for chat.completion. 
Args: - tokenizer (Tokenizer): tokenizer. + tokenizer (PreTrainedTokenizerBase): tokenizer. token_ids (list[int]): output token ids. logprobs (list[dict[int, float]]): the top logprobs for each output position. @@ -259,7 +199,7 @@ def _create_chat_completion_logprobs(tokenizer: Tokenizer, for token_id, tops in zip(token_ids, logprobs): item = ChatCompletionTokenLogprob(token='', bytes=[], logprob=0.0, top_logprobs=[]) for top_id, prob in tops.items(): - token = tokenizer.model.model.convert_ids_to_tokens(top_id) + token = tokenizer.convert_ids_to_tokens(top_id) if isinstance(token, bytes): _bytes = list(token) token = token.decode('utf-8', errors='backslashreplace') @@ -295,7 +235,8 @@ async def terminate(): # modified from https://github.com/vllm-project/vllm/blob/v0.5.4/vllm/entrypoints/openai/logits_processors.py#L51 # noqa -def logit_bias_logits_processor(logit_bias: dict[int, float] | dict[str, float], tokenizer) -> LogitsProcessor: +def logit_bias_logits_processor(logit_bias: dict[int, float] | dict[str, float], + tokenizer: PreTrainedTokenizerBase) -> LogitsProcessor: try: # Convert token_id to integer # Clamp the bias between -100 and 100 per OpenAI API spec @@ -409,8 +350,6 @@ async def chat_completions_v1(request: ChatCompletionRequest, raw_request: Reque error_check_ret = check_request(request) if error_check_ret is not None: return error_check_ret - if VariableInterface.tool_parser is not None: - request = VariableInterface.tool_parser.adjust_request(request) session = VariableInterface.get_session(request.session_id) json_request = await raw_request.json() @@ -426,31 +365,27 @@ async def chat_completions_v1(request: ChatCompletionRequest, raw_request: Reque adapter_name = model_name # got a adapter name request_id = str(session.session_id) created_time = int(time.time()) - gpt_oss_parser = None - if VariableInterface.async_engine.arch == 'GptOssForCausalLM': - gpt_oss_parser = GptOssChatParser() if isinstance(request.stop, str): request.stop = 
[request.stop] + tokenizer = VariableInterface.async_engine.tokenizer.model.model gen_logprobs, logits_processors = None, None if request.logprobs and request.top_logprobs: gen_logprobs = request.top_logprobs - response_format = None - if request.response_format and request.response_format.type != 'text': - response_format = request.response_format.model_dump() - if request.logit_bias is not None: try: logits_processors = [ - logit_bias_logits_processor(request.logit_bias, VariableInterface.async_engine.tokenizer.model) + logit_bias_logits_processor(request.logit_bias, tokenizer) ] except Exception as e: return create_error_response(HTTPStatus.BAD_REQUEST, str(e)) random_seed = request.seed if request.seed else None max_new_tokens = (request.max_completion_tokens if request.max_completion_tokens else request.max_tokens) - + response_format = None + if request.response_format and request.response_format.type != 'text': + response_format = request.response_format.model_dump() gen_config = GenerationConfig( max_new_tokens=max_new_tokens, do_sample=True, @@ -473,13 +408,17 @@ async def chat_completions_v1(request: ChatCompletionRequest, raw_request: Reque with_cache=with_cache, preserve_cache=preserve_cache, ) + parser_cls = VariableInterface.response_parser_cls + response_parser = parser_cls(request=request, tokenizer=tokenizer) + # request might be adjusted by tool parser + request = response_parser.request tools = None - if request.tools and request.tool_choice != 'none': + if request.tools: + arch = VariableInterface.async_engine.arch gen_config.skip_special_tokens = False - # internlm2 only uses contents inside function regardless of 'type' if not isinstance(request.tool_choice, str): - if gpt_oss_parser: + if arch == 'GptOssForCausalLM': tools = [ item.model_dump() for item in request.tools if item.function.name == request.tool_choice.function.name @@ -490,10 +429,11 @@ async def chat_completions_v1(request: ChatCompletionRequest, raw_request: Reque if 
item.function.name == request.tool_choice.function.name ] else: - if gpt_oss_parser: + if arch == 'GptOssForCausalLM': tools = [item.model_dump() for item in request.tools] else: tools = [item.function.model_dump() for item in request.tools] + # text completion for string input do_preprocess = False if isinstance(request.messages, str) else request.do_preprocess chat_template_kwargs = request.chat_template_kwargs or {} @@ -504,7 +444,7 @@ async def chat_completions_v1(request: ChatCompletionRequest, raw_request: Reque chat_template_kwargs['enable_thinking'] = request.enable_thinking else: logger.warning('`enable_thinking` in `chat_template_kwargs` will override the value in request.') - enable_thinking = chat_template_kwargs.get('enable_thinking', None) + result_generator = VariableInterface.async_engine.generate( request.messages, session, @@ -541,18 +481,11 @@ def create_stream_response_json(index: int, return response_json async def completion_stream_generator() -> AsyncGenerator[str, None]: - previous_text = '' - current_text = '' - previous_token_ids = [] - current_token_ids = [] - delta_token_ids = [] - has_parser = VariableInterface.tool_parser is not None or VariableInterface.reasoning_parser is not None streaming_tools = False async for res in result_generator: logprobs, usage = None, None if gen_logprobs and res.logprobs: - logprobs = _create_chat_completion_logprobs(VariableInterface.async_engine.tokenizer, res.token_ids, - res.logprobs) + logprobs = _create_chat_completion_logprobs(tokenizer, res.token_ids, res.logprobs) # Only stream chunk `usage` in the final chunk according to OpenAI API spec if (res.finish_reason and request.stream_options and request.stream_options.include_usage): total_tokens = sum([res.input_token_len, res.generate_token_len]) @@ -561,50 +494,30 @@ async def completion_stream_generator() -> AsyncGenerator[str, None]: completion_tokens=res.generate_token_len, total_tokens=total_tokens, ) - delta_token_ids = res.token_ids if 
res.token_ids is not None else [] - if gpt_oss_parser: - delta_message = gpt_oss_parser.parse_streaming(res.token_ids) - if res.finish_reason == 'stop' and len(delta_message.tool_calls) > 0: + delta_message, tool_emitted = response_parser.stream_chunk( + res.response, + delta_token_ids + ) + if tool_emitted: + streaming_tools = True + + if (request.tool_choice != 'none' and response_parser.tool_parser is not None): + if res.finish_reason == 'stop' and streaming_tools is True: res.finish_reason = 'tool_calls' - else: - delta_message = DeltaMessage(role='assistant', content=res.response) - if has_parser: - current_text = current_text + res.response - current_token_ids = current_token_ids + delta_token_ids - if request.tool_choice != 'none' and VariableInterface.tool_parser is not None: - if res.finish_reason == 'stop' and streaming_tools is True: - res.finish_reason = 'tool_calls' - tool_delta = VariableInterface.tool_parser.extract_tool_calls_streaming( - previous_text=previous_text, - current_text=current_text, - delta_text=delta_message.content, - previous_token_ids=previous_token_ids, - current_token_ids=current_token_ids, - delta_token_ids=delta_token_ids, - request=request) - if tool_delta is not None: - delta_message.tool_calls = tool_delta.tool_calls - delta_message.content = tool_delta.content - if isinstance(tool_delta.tool_calls, list) and len(tool_delta.tool_calls): - streaming_tools = True - elif (request.tool_choice != 'none' and request.tools is not None - and VariableInterface.tool_parser is None): + elif request.tool_choice != 'none' and request.tools is not None: + if parser_cls.tool_parser_cls is None: logger.error('Please launch the api_server with --tool-call-parser if you want to use tool.') - if VariableInterface.reasoning_parser is not None and enable_thinking is not False: - reasoning_delta = VariableInterface.reasoning_parser.extract_reasoning_content_streaming( - previous_text=previous_text, - current_text=current_text, - 
delta_text=delta_message.content or '', - previous_token_ids=previous_token_ids, - current_token_ids=current_token_ids, - delta_token_ids=delta_token_ids) - if reasoning_delta is not None: - delta_message.reasoning_content = reasoning_delta.reasoning_content - delta_message.content = reasoning_delta.content - if has_parser: - previous_text = current_text - previous_token_ids = current_token_ids + + # The parser may intentionally suppress no-op chunks by returning + # ``None``. Keep them suppressed unless this is a visible terminal + # frame (finish/usage/logprobs), where OpenAI-style streams still + # expect a delta object. + if delta_message is None: + if res.finish_reason is None and usage is None and logprobs is None: + continue + delta_message = DeltaMessage(role='assistant') + if request.return_token_ids: delta_message.gen_tokens = delta_token_ids response_json = create_stream_response_json(index=0, @@ -643,39 +556,31 @@ async def completion_stream_generator() -> AsyncGenerator[str, None]: cache_block_ids.append(res.cache_block_ids) remote_token_ids.append(res.token_ids) - if gpt_oss_parser: - message = gpt_oss_parser.parse_full(final_token_ids) - if final_res.finish_reason == 'stop' and len(message.tool_calls) > 0: - final_res.finish_reason = 'tool_calls' - else: - tool_calls = None - reasoning_content = None - if request.tool_choice != 'none' and VariableInterface.tool_parser is not None: - try: - tool_call_info = VariableInterface.tool_parser.extract_tool_calls(text, request=request) - text, tool_calls = tool_call_info.content, tool_call_info.tool_calls - if isinstance(tool_calls, list) and len(tool_calls): - if final_res.finish_reason == 'stop': - final_res.finish_reason = 'tool_calls' - - except Exception as e: - logger.error(f'Failed to parse {text}. 
Exception: {e}.')
-                return create_error_response(HTTPStatus.BAD_REQUEST, 'Failed to parse fc related info to json format!')
-        elif request.tool_choice != 'none' and request.tools is not None and VariableInterface.tool_parser is None:
-            logger.error('Please launch the api_server with --tool-call-parser if you want to use tool.')
+    tool_calls = None
+    reasoning_content = None
+
+    try:
+        text, tool_calls, reasoning_content = response_parser.parse_complete(text)
+        if isinstance(tool_calls, list) and len(tool_calls):
+            if final_res.finish_reason == 'stop':
+                final_res.finish_reason = 'tool_calls'
-        if VariableInterface.reasoning_parser is not None and enable_thinking is not False:
-            reasoning_content, text = VariableInterface.reasoning_parser.extract_reasoning_content(text, request)
+    except Exception as e:
+        logger.error(f'Failed to parse {text}. Exception: {e}.')
+        return create_error_response(HTTPStatus.BAD_REQUEST, 'Failed to parse fc related info to json format!')
+    if request.tool_choice != 'none' and request.tools is not None:
+        if parser_cls.tool_parser_cls is None:
+            logger.error('Please launch the api_server with --tool-call-parser if you want to use tool.')
-        message = ChatMessage(role='assistant',
-                              content=text,
-                              tool_calls=tool_calls,
-                              reasoning_content=reasoning_content)
+    message = ChatMessage(role='assistant',
+                          content=text,
+                          tool_calls=tool_calls,
+                          reasoning_content=reasoning_content)
     logprobs = None
     if gen_logprobs and len(final_logprobs):
-        logprobs = _create_chat_completion_logprobs(VariableInterface.async_engine.tokenizer, final_token_ids,
-                                                    final_logprobs)
+        logprobs = _create_chat_completion_logprobs(tokenizer, final_token_ids, final_logprobs)
 
     assert final_res is not None
     choices = []
@@ -850,17 +755,11 @@ def create_stream_response_json(index: int,
 
     async def completion_stream_generator() -> AsyncGenerator[str, None]:
         # First chunk with role
         for generator in generators:
-            offset = 0
-            all_token_ids = []
-            state = DetokenizeState()
async for res in generator: logprobs = None usage = None if request.logprobs and res.logprobs: - logprobs, offset, all_token_ids, state = _create_completion_logprobs( # noqa E501 - VariableInterface.async_engine.tokenizer, res.token_ids, res.logprobs, - gen_config.skip_special_tokens, offset, all_token_ids, state, - gen_config.spaces_between_special_tokens) + raise ValueError('logprobs is removed') # Only stream chunk `usage` in the final chunk according to OpenAI API spec if (res.finish_reason and request.stream_options and request.stream_options.include_usage): final_res = res @@ -916,14 +815,6 @@ async def _inner_call(i, generator): final_logprobs.extend(res.logprobs) logprobs = None - if request.logprobs and len(final_logprobs): - logprobs, _, _, _ = _create_completion_logprobs( - VariableInterface.async_engine.tokenizer, - final_token_ids, - final_logprobs, - gen_config.skip_special_tokens, - spaces_between_special_tokens=gen_config.spaces_between_special_tokens) - assert final_res is not None choice_data = CompletionResponseChoice(index=i, text=text, @@ -1328,27 +1219,18 @@ async def dispatch(self, request: Request, call_next): return response -def set_parsers(reasoning_parser: str | None = None, tool_parser: str | None = None): - """Set tool parser and reasoning parsers.""" - # set reasoning parser - if reasoning_parser is not None: - if reasoning_parser in ReasoningParserManager.module_dict: - tokenizer = VariableInterface.async_engine.tokenizer - VariableInterface.reasoning_parser = ReasoningParserManager.get(reasoning_parser)(tokenizer) - else: - raise ValueError( - f'The reasoning parser {reasoning_parser} is not in the parser list: {ReasoningParserManager.module_dict.keys()}' # noqa - ) - # set tool parsers - if tool_parser is not None: - if tool_parser in ToolParserManager.module_dict: - tokenizer = VariableInterface.async_engine.tokenizer - VariableInterface.tool_parser = ToolParserManager.get(tool_parser)(tokenizer) - else: - raise ValueError( - 
f'The reasoning parser {tool_parser} is not in the parser list: {ToolParserManager.module_dict.keys()}' # noqa - ) - +def set_parsers(reasoning_parser_name: str | None = None, tool_parser_name: str | None = None, **kwargs): + from lmdeploy.serve.parsers import ResponseParserManager + name = 'default' + arch = VariableInterface.async_engine.arch + if arch == 'GptOssForCausalLM': + name = 'gpt-oss' + cls = ResponseParserManager.get(name) + if cls is None: + raise ValueError(f'The response parser {name} is not in the parser list: ' + f'{ResponseParserManager.module_dict.keys()}') + cls.set_parsers(reasoning_parser_name=reasoning_parser_name, tool_parser_name=tool_parser_name) + VariableInterface.response_parser_cls = cls def mount_metrics(app: FastAPI, backend_config: PytorchEngineConfig | TurbomindEngineConfig): if not getattr(backend_config, 'enable_metrics', False): @@ -1466,7 +1348,7 @@ def serve(model_path: str, being printed in log. Default: Unlimited max_concurrent_requests: This refers to the number of concurrent requests that the server can handle. The server is designed to - process the engine’s tasks once the maximum number of concurrent + process the engine's tasks once the maximum number of concurrent requests is reached, regardless of any additional requests sent by clients concurrently during that time. Default to None. reasoning_parser (str): The reasoning parser name. @@ -1501,7 +1383,6 @@ def serve(model_path: str, max_log_len=max_log_len, speculative_config=speculative_config, **kwargs) - # set reasoning parser and tool parser set_parsers(reasoning_parser, tool_call_parser) # create FastAPI lifespan events diff --git a/lmdeploy/serve/openai/harmony_utils.py b/lmdeploy/serve/openai/harmony_utils.py deleted file mode 100644 index 2810725c0f..0000000000 --- a/lmdeploy/serve/openai/harmony_utils.py +++ /dev/null @@ -1,94 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. 
-# Modified from https://github.com/vllm-project/vllm/blob/v0.10.2rc1/vllm/entrypoints/harmony_utils.py - -import shortuuid -from openai_harmony import HarmonyEncodingName, Role, StreamableParser, load_harmony_encoding - -from lmdeploy.serve.openai.protocol import ( - ChatMessage, - DeltaFunctionCall, - DeltaMessage, - DeltaToolCall, - FunctionCall, - ToolCall, -) - -_harmony_encoding = None - - -def get_encoding(): - global _harmony_encoding - if _harmony_encoding is None: - _harmony_encoding = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS) - return _harmony_encoding - - -def get_streamable_parser_for_assistant() -> 'StreamableParser': - return StreamableParser(get_encoding(), role=Role.ASSISTANT) - - -class GptOssChatParser: - - def __init__(self): - self.parser = get_streamable_parser_for_assistant() - - def parse_streaming(self, tokens: list[int]) -> DeltaMessage: - parser = self.parser - delta_message = DeltaMessage(role='assistant') - content = '' - reasoning_content = '' - tool_calls = [] - delta_tool_call = None - for token in tokens: - prev_recipient = parser.current_recipient - parser.process(token) - cur_channel = parser.current_channel - cur_recipient = parser.current_recipient - delta_text = parser.last_content_delta or '' - if cur_channel == 'final': - content += delta_text - elif cur_channel == 'analysis': - reasoning_content += delta_text - elif cur_channel == 'commentary' and cur_recipient and cur_recipient.startswith('functions.'): - base_index = 0 - for msg in parser.messages: - if msg.channel == 'commentary' and msg.recipient and msg.recipient.startswith('functions.'): - base_index += 1 - if prev_recipient != cur_recipient: - if delta_tool_call is not None: - tool_calls.append(delta_tool_call) - tool_name = cur_recipient.split('functions.', 1)[1] - delta_tool_call = DeltaToolCall(id=f'chatcmpl-tool-{shortuuid.random()}', - type='function', - index=base_index, - function=DeltaFunctionCall(name=tool_name, arguments='')) - elif 
delta_text: - # Continuing the same tool call. Ensure we don't duplicate the - # very first delta string in this chunk. Previously we initialized - # with arguments=delta_text and then appended again, causing - # duplicated content like "locationlocation". - if delta_tool_call is None: - # We are in the middle of a tool call carried over from the - # previous chunk. Initialize an empty arguments buffer. - delta_tool_call = DeltaToolCall(index=base_index, function=DeltaFunctionCall(arguments='')) - delta_tool_call.function.arguments += delta_text - - if delta_tool_call: - tool_calls.append(delta_tool_call) - - delta_message.content = content if content else None - delta_message.reasoning_content = reasoning_content if reasoning_content else None - delta_message.tool_calls = tool_calls - return delta_message - - def parse_full(self, tokens: list[int]) -> ChatMessage: - delta_message = self.parse_streaming(tokens) - tool_calls = [] - for delta_tool_call in delta_message.tool_calls: - function = FunctionCall(**delta_tool_call.function.model_dump()) - tool_calls.append(ToolCall(id=delta_tool_call.id, type=delta_tool_call.type, function=function)) - chat_message = ChatMessage(role='assistant', - content=delta_message.content, - tool_calls=tool_calls, - reasoning_content=delta_message.reasoning_content) - return chat_message diff --git a/lmdeploy/serve/openai/protocol.py b/lmdeploy/serve/openai/protocol.py index cf4a398ea5..296f3f69e1 100644 --- a/lmdeploy/serve/openai/protocol.py +++ b/lmdeploy/serve/openai/protocol.py @@ -188,7 +188,7 @@ class ExtractedToolCallInformation(BaseModel): # indicate if tools were called tools_called: bool # extracted tool calls - tool_calls: list[ToolCall] + tool_calls: list[ToolCall] | None = None # content - per OpenAI spec, content AND tool calls can be returned rarely # But some models will do this intentionally content: str | None = None @@ -253,7 +253,7 @@ class DeltaFunctionCall(BaseModel): # a tool call delta where everything is 
optional class DeltaToolCall(BaseModel): id: str = Field(default_factory=lambda: f'chatcmpl-tool-{shortuuid.random()}') - type: Literal['function'] = 'function' + type: Literal['function'] | None = 'function' index: int function: DeltaFunctionCall | None = None @@ -264,7 +264,7 @@ class DeltaMessage(BaseModel): content: str | None = None reasoning_content: str | None = None gen_tokens: list[int] | None = None - tool_calls: list[DeltaToolCall] = Field(default_factory=list) + tool_calls: list[DeltaToolCall] | None = None class ChatCompletionResponseStreamChoice(BaseModel): diff --git a/lmdeploy/serve/openai/reasoning_parser/__init__.py b/lmdeploy/serve/openai/reasoning_parser/__init__.py deleted file mode 100644 index 09d621a252..0000000000 --- a/lmdeploy/serve/openai/reasoning_parser/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -from .deepseek_r1_reasoning_parser import DeepSeekR1ReasoningParser -from .qwen_qwq_reasoning_parser import QwenQwQReasoningParser -from .reasoning_parser import ReasoningParser, ReasoningParserManager - -__all__ = ['ReasoningParser', 'ReasoningParserManager', 'DeepSeekR1ReasoningParser', 'QwenQwQReasoningParser'] diff --git a/lmdeploy/serve/openai/reasoning_parser/deepseek_r1_reasoning_parser.py b/lmdeploy/serve/openai/reasoning_parser/deepseek_r1_reasoning_parser.py deleted file mode 100644 index d2392648e4..0000000000 --- a/lmdeploy/serve/openai/reasoning_parser/deepseek_r1_reasoning_parser.py +++ /dev/null @@ -1,140 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. 
-# modified from https://github.com/vllm-project/vllm/tree/v0.7.3/vllm/entrypoints/openai/reasoning_parsers
-import re
-from collections.abc import Sequence
-
-from lmdeploy.serve.openai.protocol import ChatCompletionRequest, DeltaMessage
-
-from .reasoning_parser import ReasoningParser, ReasoningParserManager
-
-
-@ReasoningParserManager.register_module(name='deepseek-r1')
-class DeepSeekR1ReasoningParser(ReasoningParser):
-    """Reasoning parser for DeepSeek R1 model.
-
-    The DeepSeek R1 model uses <think>...</think> tokens to denote reasoning text. This parser extracts the reasoning
-    content from the model output.
-    """
-
-    def __init__(self, tokenizer: object):
-        super().__init__(tokenizer)
-        self.think_start_token = '<think>'
-        self.think_end_token = '</think>'
-
-        self.reasoning_regex = re.compile(rf'{self.think_start_token}(.*?){self.think_end_token}', re.DOTALL)
-
-        if not self.model_tokenizer:
-            raise ValueError('The model tokenizer must be passed to the ReasoningParser '
-                             'constructor during construction.')
-
-        self.think_start_token_id = self.vocab.get(self.think_start_token)
-        self.think_end_token_id = self.vocab.get(self.think_end_token)
-        if (self.think_start_token_id is None or self.think_end_token_id is None):
-            raise RuntimeError('DeepSeek R1 reasoning parser could not locate think start/end '
-                               'tokens in the tokenizer!')
-
-    def extract_reasoning_content_streaming(
-        self,
-        previous_text: str,
-        current_text: str,
-        delta_text: str,
-        previous_token_ids: Sequence[int],
-        current_token_ids: Sequence[int],
-        delta_token_ids: Sequence[int],
-        **kwargs,
-    ) -> DeltaMessage | None:
-        """Instance method that should be implemented for extracting reasoning
-        from an incomplete response; for use when handling reasoning calls and
-        streaming.
-
-        Has to be an instance method because it requires state - the current tokens/diffs, but also the information
-        about what has previously been parsed and extracted (see constructor)
-        """
-        # Skip single special tokens
-        if len(delta_token_ids) == 1:
-            if delta_token_ids[0] == self.think_end_token_id:
-                return DeltaMessage(content='')
-            elif delta_token_ids[0] == self.think_start_token_id:
-                return None
-
-        # Check if <think> is present in previous or delta.
-        # Keep compatibility with models that don't generate <think> tokens.
-        if self.think_start_token_id in previous_token_ids:
-            if self.think_end_token_id in delta_token_ids:
-                # <think> in previous, </think> in delta,
-                # extract reasoning content
-                end_index = delta_text.find(self.think_end_token)
-                reasoning_content = delta_text[:end_index]
-                content = delta_text[end_index + len(self.think_end_token):]
-                return DeltaMessage(reasoning_content=reasoning_content, content=content if content else None)
-            elif self.think_end_token_id in previous_token_ids:
-                # <think> in previous, </think> in previous,
-                return DeltaMessage(content=delta_text)
-            else:
-                # <think> in previous, no </think> in previous or delta,
-                # reasoning content continues
-                return DeltaMessage(reasoning_content=delta_text)
-        elif self.think_start_token_id in delta_token_ids:
-            if self.think_end_token_id in delta_token_ids:
-                # <think> in delta, </think> in delta, extract reasoning content
-                start_index = delta_text.find(self.think_start_token)
-                end_index = delta_text.find(self.think_end_token)
-                reasoning_content = delta_text[start_index + len(self.think_start_token):end_index]
-                content = delta_text[end_index + len(self.think_end_token):]
-                return DeltaMessage(reasoning_content=reasoning_content, content=content if content else None)
-            else:
-                # <think> in delta, no </think> in delta,
-                # reasoning content continues
-                return DeltaMessage(reasoning_content=delta_text)
-        else:
-            # No <think> in previous or delta, also need to check for </think>.
-            # Because the model may have generated </think> without <think>
-            # Ref https://huggingface.co/deepseek-ai/DeepSeek-R1/commit/8a58a132790c9935686eb97f042afa8013451c9f
-            if self.think_end_token_id in delta_token_ids:
-                # </think> in delta with more tokens,
-                # extract reasoning content and content
-                end_index = delta_text.find(self.think_end_token)
-                reasoning_content = delta_text[:end_index]
-                content = delta_text[end_index + len(self.think_end_token):]
-                return DeltaMessage(reasoning_content=reasoning_content, content=content if content else None)
-            elif self.think_end_token_id in previous_token_ids:
-                # </think> in previous, thinking content ends
-                return DeltaMessage(content=delta_text)
-            else:
-                # no </think> in previous or delta, reasoning content continues
-                return DeltaMessage(reasoning_content=delta_text)
-
-    def extract_reasoning_content(self, model_output: str, request: ChatCompletionRequest,
-                                  **kwargs) -> tuple[str | None, str | None]:
-        """Extract reasoning content from a complete model-generated string.
-
-        Used for non-streaming responses where we have the entire model response
-        available before sending to the client.
-
-        Args:
-            model_output (str): The model-generated string to extract reasoning content from.
-            request (ChatCompletionRequest): The request object that was used to generate the model_output.
-
-        Returns:
-            reasoning_content (str | None): The reasoning content.
-            final_output (str | None): The content.
-        """
-        # DeepSeek R1 doesn't generate <think> now.
-        # Thus we assume the reasoning content is always at the start.
-        # Ref https://huggingface.co/deepseek-ai/DeepSeek-R1/commit/8a58a132790c9935686eb97f042afa8013451c9f
-        if self.think_end_token not in model_output:
-            return model_output, None
-        else:
-            # Add a start token if it's missing to keep compatibility.
-            if self.think_start_token not in model_output:
-                model_output = f'{self.think_start_token}{model_output}'
-            # Use a regex to find the reasoning content
-            reasoning_content = self.reasoning_regex.findall(model_output)[0]
-
-            end_index = len(f'{self.think_start_token}{reasoning_content}{self.think_end_token}')
-            final_output = model_output[end_index:]
-
-            if len(final_output) == 0:
-                return reasoning_content, None
-
-            return reasoning_content, final_output
diff --git a/lmdeploy/serve/openai/reasoning_parser/qwen_qwq_reasoning_parser.py b/lmdeploy/serve/openai/reasoning_parser/qwen_qwq_reasoning_parser.py
deleted file mode 100644
index 63f35d76e6..0000000000
--- a/lmdeploy/serve/openai/reasoning_parser/qwen_qwq_reasoning_parser.py
+++ /dev/null
@@ -1,134 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import re
-from collections.abc import Sequence
-
-from lmdeploy.serve.openai.protocol import ChatCompletionRequest, DeltaMessage
-
-from .reasoning_parser import ReasoningParser, ReasoningParserManager
-
-
-@ReasoningParserManager.register_module(name=['qwen-qwq', 'intern-s1'])
-class QwenQwQReasoningParser(ReasoningParser):
-    """Reasoning parser for Qwen QwQ model.
-
-    The Qwen QwQ model uses <think>...</think> tokens to denote reasoning text. This parser extracts the reasoning
-    content from the model output.
-    """
-
-    def __init__(self, tokenizer: object):
-        super().__init__(tokenizer)
-        self.think_start_token = '<think>'
-        self.think_end_token = '</think>'
-
-        self.reasoning_regex = re.compile(rf'{self.think_start_token}(.*?){self.think_end_token}', re.DOTALL)
-
-        if not self.model_tokenizer:
-            raise ValueError('The model tokenizer must be passed to the ReasoningParser '
-                             'constructor during construction.')
-
-    def extract_reasoning_content_streaming(
-        self,
-        previous_text: str,
-        current_text: str,
-        delta_text: str,
-        previous_token_ids: Sequence[int],
-        current_token_ids: Sequence[int],
-        delta_token_ids: Sequence[int],
-        **kwargs,
-    ) -> DeltaMessage | None:
-        """Instance method that should be implemented for extracting reasoning
-        from an incomplete response; for use when handling reasoning calls and
-        streaming.
-
-        Has to be an instance method because it requires state - the current tokens/diffs, but also the information
-        about what has previously been parsed and extracted (see constructor)
-        """
-        # Skip single special tokens
-        if delta_text == self.think_end_token or delta_text == self.think_start_token:
-            return DeltaMessage(content='')
-
-        # Check if <think> is present in previous or delta.
-        # Keep compatibility with models that don't generate <think> tokens.
-        if self.think_start_token in previous_text:
-            if self.think_end_token in delta_text:
-                # <think> in previous, </think> in delta,
-                # extract reasoning content
-                end_index = delta_text.find(self.think_end_token)
-                reasoning_content = delta_text[:end_index]
-                content = delta_text[end_index + len(self.think_end_token):]
-                return DeltaMessage(reasoning_content=reasoning_content, content=content if content else None)
-            elif self.think_end_token in previous_text:
-                # <think> in previous, </think> in previous,
-                return DeltaMessage(content=delta_text)
-            else:
-                # <think> in previous, no </think> in previous or delta,
-                # reasoning content continues
-                return DeltaMessage(reasoning_content=delta_text)
-        elif self.think_start_token in delta_text:
-            if self.think_end_token in delta_text:
-                # <think> in delta, </think> in delta, extract reasoning content
-                start_index = delta_text.find(self.think_start_token)
-                end_index = delta_text.find(self.think_end_token)
-                reasoning_content = delta_text[start_index + len(self.think_start_token):end_index]
-                content = delta_text[end_index + len(self.think_end_token):]
-                return DeltaMessage(reasoning_content=reasoning_content, content=content if content else None)
-            else:
-                # <think> in delta, no </think> in delta,
-                # reasoning content continues
-                return DeltaMessage(reasoning_content=delta_text)
-        else:
-            # No <think> in previous or delta, also need to check for </think>.
-            # Because the model may have generated </think> without <think>
-            # Ref https://huggingface.co/deepseek-ai/DeepSeek-R1/commit/8a58a132790c9935686eb97f042afa8013451c9f
-            if self.think_end_token in delta_text:
-                # </think> in delta with more tokens,
-                # extract reasoning content and content
-                end_index = delta_text.find(self.think_end_token)
-                reasoning_content = delta_text[:end_index]
-                content = delta_text[end_index + len(self.think_end_token):]
-                return DeltaMessage(reasoning_content=reasoning_content, content=content if content else None)
-            elif self.think_end_token in previous_text:
-                # </think> in previous, thinking content ends
-                return DeltaMessage(content=delta_text)
-            else:
-                # no </think> in previous or delta, reasoning content continues
-                return DeltaMessage(reasoning_content=delta_text)
-
-    def extract_reasoning_content(self, model_output: str, request: ChatCompletionRequest,
-                                  **kwargs) -> tuple[str | None, str | None]:
-        """Extract reasoning content from a complete model-generated string.
-
-        Used for non-streaming responses where we have the entire model response
-        available before sending to the client.
-
-        Args:
-            model_output (str): The model-generated string to extract reasoning content from.
-            request (ChatCompletionRequest): The request object that was used to generate the model_output.
-
-        Returns:
-            reasoning_content (str | None): The reasoning content.
-            final_output (str | None): The content.
-        """
-        # DeepSeek R1 doesn't generate <think> now.
-        # Thus we assume the reasoning content is always at the start.
-        # Ref https://huggingface.co/deepseek-ai/DeepSeek-R1/commit/8a58a132790c9935686eb97f042afa8013451c9f
-        if self.think_end_token not in model_output:
-            # for qwen3 model, the reasoning content is wrapped by xml tags
-            return None, model_output
-        # Add a start token if it's missing to keep compatibility.
-        if self.think_start_token not in model_output:
-            model_output = f'{self.think_start_token}{model_output}'
-        # Use a regex to find the reasoning content
-        reasoning_content = self.reasoning_regex.findall(model_output)[0]
-
-        end_index = len(f'{self.think_start_token}{reasoning_content}{self.think_end_token}')
-        final_output = model_output[end_index:]
-        if reasoning_content.startswith('\n'):
-            reasoning_content = reasoning_content[1:]
-        if reasoning_content.endswith('\n'):
-            reasoning_content = reasoning_content[:-1]
-
-        if len(final_output) == 0:
-            return reasoning_content, None
-
-        return reasoning_content, final_output
diff --git a/lmdeploy/serve/openai/reasoning_parser/reasoning_parser.py b/lmdeploy/serve/openai/reasoning_parser/reasoning_parser.py
deleted file mode 100644
index 7abb62069d..0000000000
--- a/lmdeploy/serve/openai/reasoning_parser/reasoning_parser.py
+++ /dev/null
@@ -1,60 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-# modified from https://github.com/vllm-project/vllm/tree/v0.7.3/vllm/entrypoints/openai/reasoning_parsers
-from collections.abc import Sequence
-from functools import cached_property
-
-from mmengine import Registry
-
-from lmdeploy.serve.openai.protocol import ChatCompletionRequest, DeltaMessage
-
-ReasoningParserManager = Registry('reasoning_parser', locations=['lmdeploy.serve.openai.reasoning_parser'])
-
-
-class ReasoningParser:
-
-    def __init__(self, tokenizer: object):
-        self.model_tokenizer = tokenizer
-
-    @cached_property
-    def vocab(self) -> dict[str, int]:
-        # NOTE: Only PreTrainedTokenizerFast is guaranteed to have .vocab
-        # whereas all tokenizers have .get_vocab()
-        return self.model_tokenizer.get_vocab()
-
-    def extract_reasoning_content_streaming(
-        self,
-        previous_text: str,
-        current_text: str,
-        delta_text: str,
-        previous_token_ids: Sequence[int],
-        current_token_ids: Sequence[int],
-        delta_token_ids: Sequence[int],
-        **kwargs,
-    ) -> DeltaMessage | None:
-        """Instance method that should be implemented for extracting reasoning
-        from an incomplete response; for use when handling reasoning calls and
-        streaming.
-
-        Has to be an instance method because it requires state - the current tokens/diffs, but also the information
-        about what has previously been parsed and extracted (see constructor)
-        """
-        raise NotImplementedError('ReasoningParser.extract_reasoning_content_streaming '
-                                  'has not been implemented!')
-
-    def extract_reasoning_content(self, model_output: str, request: ChatCompletionRequest,
-                                  **kwargs) -> tuple[str | None, str | None]:
-        """Extract reasoning content from a complete model-generated string.
-
-        Used for non-streaming responses where we have the entire model response
-        available before sending to the client.
-
-        Args:
-            model_output (str): The model-generated string to extract reasoning content from.
-            request (ChatCompletionRequest): The request object that was used to generate the model_output.
-
-        Returns:
-            reasoning_content (str | None): The reasoning content.
-            final_output (str | None): The content.
-        """
-        raise NotImplementedError('ReasoningParser.extract_reasoning_content '
-                                  'has not been implemented!')
diff --git a/lmdeploy/serve/openai/tool_parser/internlm2_parser.py b/lmdeploy/serve/openai/tool_parser/internlm2_parser.py
deleted file mode 100644
index 89e2bb471e..0000000000
--- a/lmdeploy/serve/openai/tool_parser/internlm2_parser.py
+++ /dev/null
@@ -1,189 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
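The non-streaming contract of `ReasoningParser.extract_reasoning_content` above can be illustrated in isolation. The sketch below is not lmdeploy's module; it is a minimal, dependency-free reimplementation of the DeepSeek R1 extraction rule shown earlier (everything before `</think>` is reasoning, the rest is content, and a missing `<think>` prefix is tolerated), with the function name `extract_reasoning` chosen for illustration.

```python
# Standalone sketch of the DeepSeekR1ReasoningParser non-streaming rule.
import re

THINK_START = '<think>'
THINK_END = '</think>'
REASONING_RE = re.compile(rf'{THINK_START}(.*?){THINK_END}', re.DOTALL)


def extract_reasoning(model_output: str):
    """Return (reasoning_content, final_output); either may be None."""
    if THINK_END not in model_output:
        # R1 may omit the think markers entirely; treat it all as reasoning.
        return model_output, None
    if THINK_START not in model_output:
        # R1 sometimes omits the opening tag; prepend it so the regex matches.
        model_output = f'{THINK_START}{model_output}'
    reasoning = REASONING_RE.findall(model_output)[0]
    final = model_output[len(f'{THINK_START}{reasoning}{THINK_END}'):]
    return reasoning, (final or None)
```

Feeding it `'step</think>answer'` returns `('step', 'answer')`, matching the missing-`<think>` compatibility path in the parser above.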
-# modified from https://github.com/vllm-project/vllm/tree/v0.7.3/vllm/entrypoints/openai/tool_parsers
-import json
-from collections.abc import Sequence
-
-import partial_json_parser
-import shortuuid
-from partial_json_parser.core.options import Allow
-
-from lmdeploy.serve.openai.protocol import (
-    ChatCompletionRequest,
-    DeltaFunctionCall,
-    DeltaMessage,
-    DeltaToolCall,
-    ExtractedToolCallInformation,
-    FunctionCall,
-    ToolCall,
-)
-from lmdeploy.utils import get_logger
-
-from .tool_parser import ToolParser, ToolParserManager
-from .utils import extract_intermediate_diff
-
-logger = get_logger('lmdeploy')
-
-
-@ToolParserManager.register_module(['internlm', 'intern-s1'])
-class Internlm2ToolParser(ToolParser):
-
-    def __init__(self, tokenizer: object):
-        super().__init__(tokenizer)
-        self.position = 0
-
-    def adjust_request(self, request: ChatCompletionRequest) -> ChatCompletionRequest:
-        if request.tools and request.tool_choice != 'none':
-            # do not skip special tokens because internlm uses the special
-            # tokens to indicate the start and end of the tool calls
-            # information.
-            request.skip_special_tokens = False
-        return request
-
-    def get_argments(self, obj):
-        if 'parameters' in obj:
-            return obj.get('parameters')
-        elif 'arguments' in obj:
-            return obj.get('arguments')
-        return None
-
-    def extract_tool_calls_streaming(
-        self,
-        previous_text: str,
-        current_text: str,
-        delta_text: str,
-        previous_token_ids: Sequence[int],
-        current_token_ids: Sequence[int],
-        delta_token_ids: Sequence[int],
-        request: ChatCompletionRequest,
-    ) -> DeltaMessage | None:
-        if '<|action_start|>' not in current_text:
-            self.position = len(current_text)
-            return DeltaMessage(content=delta_text)
-        # if the tool call has been sent, return an empty delta message
-        # to make sure the finish_reason will be sent correctly.
-        if self.current_tool_id > 0:
-            return DeltaMessage(content='')
-
-        last_pos = self.position
-        if '<|action_start|><|plugin|>\n' not in current_text[last_pos:]:
-            return None
-
-        new_delta = current_text[last_pos:]
-        text, action = new_delta.split('<|action_start|><|plugin|>\n')
-
-        if len(text) > 0:
-            self.position = self.position + len(text)
-            return DeltaMessage(content=text)
-
-        action = action.strip()
-        action = action.split('<|action_end|>'.strip())[0]
-
-        # bit mask flags for partial JSON parsing. If the name hasn't been
-        # sent yet, don't allow sending
-        # an incomplete string since OpenAI only ever (as far as I have
-        # seen) allows sending the entire tool/ function name at once.
-        flags = Allow.ALL if self.current_tool_name_sent \
-            else Allow.ALL & ~Allow.STR
-
-        try:
-            parsable_arr = action
-
-            # tool calls are generated in an object in internlm2;
-            # it does not support parallel tool calls
-            try:
-                tool_call_arr: dict = partial_json_parser.loads(parsable_arr, flags)
-            except partial_json_parser.core.exceptions.MalformedJSON:
-                logger.debug('not enough tokens to parse into JSON yet')
-                return None
-
-            # if the current tool name hasn't been sent, send if available
-            # - otherwise send nothing
-            if not self.current_tool_name_sent:
-                function_name = tool_call_arr.get('name')
-                if function_name:
-                    self.current_tool_id = self.current_tool_id + 1
-                    delta = DeltaMessage(tool_calls=[
-                        DeltaToolCall(index=self.current_tool_id,
-                                      type='function',
-                                      id=f'chatcmpl-tool-{shortuuid.random()}',
-                                      function=DeltaFunctionCall(name=function_name).model_dump(exclude_none=True))
-                    ])
-                    self.current_tool_name_sent = True
-                    self.streamed_args_for_tool.append('')
-                else:
-                    delta = None
-            # now we know we're on the same tool call and we're streaming
-            # arguments
-            else:
-                prev_arguments = self.get_argments(self.prev_tool_call_arr[self.current_tool_id])
-                cur_arguments = self.get_argments(tool_call_arr)
-
-                # not arguments generated
-                if not cur_arguments and not prev_arguments:
-                    delta = None
-                # will never happen
-                elif not cur_arguments and prev_arguments:
-                    logger.error('INVARIANT - impossible to have arguments reset '
-                                 'mid-arguments')
-                    delta = None
-                # first time to get parameters
-                elif cur_arguments and not prev_arguments:
-                    cur_arguments_json = json.dumps(cur_arguments, ensure_ascii=False)
-
-                    arguments_delta = cur_arguments_json[:cur_arguments_json.index(delta_text) + len(delta_text)]
-                    delta = DeltaMessage(tool_calls=[
-                        DeltaToolCall(index=self.current_tool_id,
-                                      function=DeltaFunctionCall(arguments=arguments_delta).model_dump(
-                                          exclude_none=True))
-                    ])
-                    self.streamed_args_for_tool[self.current_tool_id] += arguments_delta
-                # both prev and cur parameters, send the increase parameters
-                elif cur_arguments and prev_arguments:
-                    cur_args_json = json.dumps(cur_arguments, ensure_ascii=False)
-                    prev_args_json = json.dumps(prev_arguments, ensure_ascii=False)
-
-                    argument_diff = extract_intermediate_diff(cur_args_json, prev_args_json)
-
-                    delta = DeltaMessage(tool_calls=[
-                        DeltaToolCall(index=self.current_tool_id,
-                                      function=DeltaFunctionCall(arguments=argument_diff).model_dump(exclude_none=True))
-                    ])
-                    self.streamed_args_for_tool[self.current_tool_id] += argument_diff
-
-            # check to see if the name is defined and has been sent. if so,
-            # stream the name - otherwise keep waiting
-            # finish by setting old and returning None as base case
-            tool_call_arr['arguments'] = self.get_argments(tool_call_arr)
-            self.prev_tool_call_arr = [tool_call_arr]
-            return delta
-        except Exception:
-            logger.exception('Error trying to handle streaming tool call.')
-            logger.debug('Skipping chunk as a result of tool streaming extraction '
-                         'error')
-            return None
-
-    def extract_tool_calls(
-        self,
-        model_output: str,
-        request: ChatCompletionRequest,
-    ) -> ExtractedToolCallInformation:
-        text = model_output
-        tools = request.tools
-        if '<|action_start|><|plugin|>' in text:
-            text, action = text.split('<|action_start|><|plugin|>')
-            action = action.split('<|action_end|>'.strip())[0]
-            action = action[action.find('{'):]
-            action_dict = json.loads(action)
-            name, parameters = action_dict['name'], json.dumps(action_dict.get('parameters',
-                                                                              action_dict.get('arguments', {})),
-                                                              ensure_ascii=False)
-
-            if not tools or name not in [t.function.name for t in tools]:
-                return ExtractedToolCallInformation(tools_called=False, tool_calls=[], content=text)
-
-            tool_calls = [ToolCall(function=FunctionCall(name=name, arguments=parameters))]
-            return ExtractedToolCallInformation(tools_called=True,
-                                                tool_calls=tool_calls,
-                                                content=text if len(text) > 0 else None)
-
-        return ExtractedToolCallInformation(tools_called=False, tool_calls=[], content=text)
diff --git a/lmdeploy/serve/openai/tool_parser/llama3_parser.py b/lmdeploy/serve/openai/tool_parser/llama3_parser.py
deleted file mode 100644
index 445cad312f..0000000000
--- a/lmdeploy/serve/openai/tool_parser/llama3_parser.py
+++ /dev/null
@@ -1,208 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
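The non-streaming half of the internlm2 parser above reduces to a simple string protocol: free text, then JSON between `<|action_start|><|plugin|>` and `<|action_end|>`. The sketch below is a dependency-free toy, not lmdeploy's API; the function name `extract_internlm2_tool_call` and its return shape are illustrative only.

```python
# Toy sketch of the internlm2 non-streaming tool-call extraction:
# free text, then a JSON object wrapped in action markers.
import json


def extract_internlm2_tool_call(model_output: str):
    """Return (content, tool_call_dict); tool_call_dict is None if absent."""
    marker = '<|action_start|><|plugin|>'
    if marker not in model_output:
        return model_output, None
    text, action = model_output.split(marker, 1)
    # keep only the JSON payload before the end marker
    action = action.split('<|action_end|>')[0]
    # tolerate a leading newline before the JSON object
    action = action[action.find('{'):]
    call = json.loads(action)
    # internlm2 may emit either 'parameters' or 'arguments'
    args = call.get('parameters', call.get('arguments', {}))
    return (text or None), {'name': call['name'], 'arguments': args}
```

This mirrors the parser's tolerance for both `parameters` and `arguments` keys; the real implementation additionally validates the name against the request's tool list.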
-import json -import re -from collections.abc import Sequence - -import partial_json_parser -import shortuuid -from partial_json_parser.core.options import Allow - -from lmdeploy.serve.openai.protocol import ( - ChatCompletionRequest, - DeltaFunctionCall, - DeltaMessage, - DeltaToolCall, - ExtractedToolCallInformation, - FunctionCall, - ToolCall, -) -from lmdeploy.utils import get_logger - -from .tool_parser import ToolParser, ToolParserManager -from .utils import find_common_prefix, is_complete_json, partial_json_loads - -logger = get_logger('lmdeploy') - - -@ToolParserManager.register_module('llama3') -class Llama3JsonToolParser(ToolParser): - """Tool call parser for Llama 3.1 models intended for use with the - examples/tool_chat_template_llama.jinja template. - - Used when --tool-call-parser llama3 are all set - """ - - def __init__(self, tokenizer: object): - super().__init__(tokenizer) - - # initialize properties used for state when parsing tool calls in - # streaming mode - self.prev_tool_call_arr: list[dict] = [] - self.current_tool_id: int = -1 - self.current_tool_name_sent: bool = False - self.streamed_args_for_tool: list[str] = [] # map what has been streamed for each tool so far to a list - self.bot_token = '<|python_tag|>' - self.bot_token_id = tokenizer.encode(self.bot_token, add_special_tokens=False)[0] - self.tool_call_regex = re.compile(r'\[{.*?}\]', re.DOTALL) - - def extract_tool_calls(self, model_output: str, request: ChatCompletionRequest) -> ExtractedToolCallInformation: - """Extract the tool calls from a complete model response.""" - try: - # load the JSON, and then use it to build the Function and - # Tool Call - action, _ = model_output.split('') - parameters = action[action.find('{'):] - name = action.split('{')[0] - call_info_list = [(name, parameters)] - - tool_calls: list[ToolCall] = [ - ToolCall(type='function', function=FunctionCall(name=name, arguments=arguments)) - for name, arguments in call_info_list - ] - - # get any content 
before the tool call - ret = ExtractedToolCallInformation(tools_called=True, tool_calls=tool_calls, content=None) - return ret - - except Exception: - logger.exception('Error in extracting tool call from response.') - # return information to just treat the tool call as regular JSON - return ExtractedToolCallInformation(tools_called=False, tool_calls=[], content=model_output) - - def extract_tool_calls_streaming( - self, - previous_text: str, - current_text: str, - delta_text: str, - previous_token_ids: Sequence[int], - current_token_ids: Sequence[int], - delta_token_ids: Sequence[int], - request: ChatCompletionRequest, - ) -> DeltaMessage | None: - - if not (current_text.startswith(self.bot_token) or current_text.startswith('{')): - return DeltaMessage(content=delta_text) - - # bit mask flags for partial JSON parsing. If the name hasn't been - # sent yet, don't allow sending - # an incomplete string since OpenAI only ever (as far as I have - # seen) allows sending the entire tool/ function name at once. 
- flags = Allow.ALL if self.current_tool_name_sent \ - else Allow.ALL & ~Allow.STR - try: - tool_call_arr = [] - is_complete = [] - try: - # depending on the prompt format the Llama model may or may not - # prefix the output with the <|python_tag|> token - start_idx = len(self.bot_token) if current_text.startswith(self.bot_token) else 0 - while start_idx < len(current_text): - (obj, end_idx) = partial_json_loads(current_text[start_idx:], flags) - is_complete.append(is_complete_json(current_text[start_idx:start_idx + end_idx])) - start_idx += end_idx + len('; ') - # depending on the prompt Llama can use - # either arguments or parameters - if 'parameters' in obj: - assert 'arguments' not in obj, \ - 'model generated both parameters and arguments' - obj['arguments'] = obj['parameters'] - tool_call_arr.append(obj) - except partial_json_parser.core.exceptions.MalformedJSON: - logger.debug('not enough tokens to parse into JSON yet') - return None - - # select as the current tool call the one we're on the state at - current_tool_call: dict = tool_call_arr[self.current_tool_id] \ - if len(tool_call_arr) > 0 else {} - - # case -- if no tokens have been streamed for the tool, e.g. - # only the array brackets, stream nothing - if len(tool_call_arr) == 0: - return None - - # case: we are starting a new tool in the array - # -> array has > 0 length AND length has moved past cursor - elif (len(tool_call_arr) > 0 and len(tool_call_arr) > self.current_tool_id + 1): - - # if we're moving on to a new call, first make sure we - # haven't missed anything in the previous one that was - # auto-generated due to JSON completions, but wasn't - # streamed to the client yet. 
- if self.current_tool_id >= 0: - cur_arguments = current_tool_call.get('arguments') - if cur_arguments: - cur_args_json = json.dumps(cur_arguments, ensure_ascii=False) - sent = len(self.streamed_args_for_tool[self.current_tool_id]) - argument_diff = cur_args_json[sent:] - - logger.debug('got arguments diff: %s', argument_diff) - delta = DeltaMessage(tool_calls=[ - DeltaToolCall(index=self.current_tool_id, - function=DeltaFunctionCall(arguments=argument_diff).model_dump( - exclude_none=True)) - ]) - self.streamed_args_for_tool[self.current_tool_id] += argument_diff - else: - delta = None - else: - delta = None - # re-set stuff pertaining to progress in the current tool - self.current_tool_id = len(tool_call_arr) - 1 - self.current_tool_name_sent = False - self.streamed_args_for_tool.append('') - logger.debug('starting on new tool %d', self.current_tool_id) - return delta - - # if the current tool name hasn't been sent, send if available - # - otherwise send nothing - elif not self.current_tool_name_sent: - function_name = current_tool_call.get('name') - if function_name: - - delta = DeltaMessage(tool_calls=[ - DeltaToolCall(index=self.current_tool_id, - type='function', - id=f'chatcmpl-tool-{shortuuid.random()}', - function=DeltaFunctionCall(name=function_name).model_dump(exclude_none=True)) - ]) - self.current_tool_name_sent = True - else: - delta = None - - # now we know we're on the same tool call and we're streaming - # arguments - else: - cur_arguments = current_tool_call.get('arguments') - delta = None - - if cur_arguments: - sent = len(self.streamed_args_for_tool[self.current_tool_id]) - cur_args_json = json.dumps(cur_arguments, ensure_ascii=False) - prev_arguments = self.prev_tool_call_arr[self.current_tool_id].get('arguments') - - argument_diff = None - if is_complete[self.current_tool_id]: - argument_diff = cur_args_json[sent:] - elif prev_arguments: - prev_args_json = json.dumps(prev_arguments, ensure_ascii=False) - if cur_args_json != prev_args_json: - 
-                            prefix = find_common_prefix(prev_args_json, cur_args_json)
-                            argument_diff = prefix[sent:]
-
-                    if argument_diff is not None:
-                        delta = DeltaMessage(tool_calls=[
-                            DeltaToolCall(index=self.current_tool_id,
-                                          function=DeltaFunctionCall(arguments=argument_diff).model_dump(
-                                              exclude_none=True))
-                        ])
-                        self.streamed_args_for_tool[self.current_tool_id] += argument_diff
-
-            self.prev_tool_call_arr = tool_call_arr
-            return delta
-
-        except Exception:
-            logger.exception('Error trying to handle streaming tool call.')
-            logger.debug('Skipping chunk as a result of tool streaming extraction '
-                         'error')
-            return None
diff --git a/lmdeploy/serve/openai/tool_parser/qwen2d5_parser.py b/lmdeploy/serve/openai/tool_parser/qwen2d5_parser.py
deleted file mode 100644
index eb87d1f97a..0000000000
--- a/lmdeploy/serve/openai/tool_parser/qwen2d5_parser.py
+++ /dev/null
@@ -1,187 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import json
-import re
-from collections.abc import Sequence
-
-import partial_json_parser
-import shortuuid
-from partial_json_parser.core.options import Allow
-
-from lmdeploy.serve.openai.protocol import (
-    ChatCompletionRequest,
-    DeltaFunctionCall,
-    DeltaMessage,
-    DeltaToolCall,
-    ExtractedToolCallInformation,
-    FunctionCall,
-    ToolCall,
-)
-from lmdeploy.utils import get_logger
-
-from .tool_parser import ToolParser, ToolParserManager
-from .utils import extract_intermediate_diff
-
-logger = get_logger('lmdeploy')
-
-
-@ToolParserManager.register_module(['qwen2d5'])
-class Qwen2d5ToolParser(ToolParser):
-
-    def __init__(self, tokenizer: object):
-        super().__init__(tokenizer)
-        self.position = 0
-        self.tool_start_token = '<tool_call>'
-        self.tool_end_token = '</tool_call>'
-        self.pattern = r'<tool_call>(.*?)</tool_call>'
-
-    def get_argments(self, obj):
-        if 'parameters' in obj:
-            return obj.get('parameters')
-        elif 'arguments' in obj:
-            return obj.get('arguments')
-        return None
-
-    def extract_tool_calls_streaming(
-        self,
-        previous_text: str,
-        current_text: str,
-        delta_text: str,
-        previous_token_ids: Sequence[int],
-        current_token_ids: Sequence[int],
-        delta_token_ids: Sequence[int],
-        request: ChatCompletionRequest,
-    ) -> DeltaMessage | None:
-        if self.tool_start_token not in current_text:
-            self.position = len(current_text)
-            return DeltaMessage(content=delta_text)
-        # if the tool call has been sent, return an empty delta message
-        # to make sure the finish_reason will be sent correctly.
-        if self.current_tool_id > 0:
-            return DeltaMessage(content='')
-
-        last_pos = self.position
-        if self.tool_start_token not in current_text[last_pos:]:
-            return None
-
-        new_delta = current_text[last_pos:]
-        text, action = new_delta.split(self.tool_start_token)
-
-        if len(text) > 0:
-            self.position = self.position + len(text)
-            return DeltaMessage(content=text)
-
-        action = action.strip()
-        action = action.split(self.tool_end_token.strip())[0]
-
-        # bit mask flags for partial JSON parsing. If the name hasn't been
-        # sent yet, don't allow sending
-        # an incomplete string since OpenAI only ever (as far as I have
-        # seen) allows sending the entire tool/ function name at once.
- flags = Allow.ALL if self.current_tool_name_sent \ - else Allow.ALL & ~Allow.STR - - try: - parsable_arr = action - - # tool calls are generated in an object as in internlm2; - # parallel tool calls are not supported - try: - tool_call_arr: dict = partial_json_parser.loads(parsable_arr, flags) - except partial_json_parser.core.exceptions.MalformedJSON: - logger.debug('not enough tokens to parse into JSON yet') - return None - - # if the current tool name hasn't been sent, send if available - # - otherwise send nothing - if not self.current_tool_name_sent: - function_name = tool_call_arr.get('name') - if function_name: - self.current_tool_id = self.current_tool_id + 1 - delta = DeltaMessage(tool_calls=[ - DeltaToolCall(index=self.current_tool_id, - type='function', - id=f'chatcmpl-tool-{shortuuid.random()}', - function=DeltaFunctionCall(name=function_name).model_dump(exclude_none=True)) - ]) - self.current_tool_name_sent = True - self.streamed_args_for_tool.append('') - else: - delta = None - # now we know we're on the same tool call and we're streaming - # arguments - else: - prev_arguments = self.get_argments(self.prev_tool_call_arr[self.current_tool_id]) - cur_arguments = self.get_argments(tool_call_arr) - - # no arguments generated yet - if not cur_arguments and not prev_arguments: - delta = None - # will never happen - elif not cur_arguments and prev_arguments: - logger.error('INVARIANT - impossible to have arguments reset ' - 'mid-arguments') - delta = None - # first time to get parameters - elif cur_arguments and not prev_arguments: - cur_arguments_json = json.dumps(cur_arguments, ensure_ascii=False) - - arguments_delta = cur_arguments_json[:cur_arguments_json.index(delta_text) + len(delta_text)] - delta = DeltaMessage(tool_calls=[ - DeltaToolCall(index=self.current_tool_id, - function=DeltaFunctionCall(arguments=arguments_delta).model_dump( - exclude_none=True)) - ]) - self.streamed_args_for_tool[self.current_tool_id] += arguments_delta - # both prev and cur
parameters are present, send the incremental diff - elif cur_arguments and prev_arguments: - cur_args_json = json.dumps(cur_arguments, ensure_ascii=False) - prev_args_json = json.dumps(prev_arguments, ensure_ascii=False) - - argument_diff = extract_intermediate_diff(cur_args_json, prev_args_json) - - delta = DeltaMessage(tool_calls=[ - DeltaToolCall(index=self.current_tool_id, - function=DeltaFunctionCall(arguments=argument_diff).model_dump(exclude_none=True)) - ]) - self.streamed_args_for_tool[self.current_tool_id] += argument_diff - - # check to see if the name is defined and has been sent. if so, - # stream the name - otherwise keep waiting - # finish by setting old and returning None as base case - tool_call_arr['arguments'] = self.get_argments(tool_call_arr) - self.prev_tool_call_arr = [tool_call_arr] - return delta - except Exception: - logger.exception('Error trying to handle streaming tool call.') - logger.debug('Skipping chunk as a result of tool streaming extraction ' - 'error') - return None - - def extract_tool_calls( - self, - model_output: str, - request: ChatCompletionRequest, - ) -> ExtractedToolCallInformation: - text = model_output - if self.tool_start_token in text: - - # get tool_call in text - match_result_list = re.findall(self.pattern, text, re.DOTALL) - tool_calls = [] - for match_result in match_result_list: - action = json.loads(match_result) - name, arguments = action['name'], json.dumps(action['arguments'], ensure_ascii=False) - tool_calls.append(ToolCall(function=FunctionCall(name=name, arguments=arguments))) - - # get text outside of <tool_call> tags - if not text.startswith('<tool_call>'): - text = text[:text.find('<tool_call>')] - elif not text.endswith('</tool_call>'): - text = text[text.rfind('</tool_call>') + len('</tool_call>'):] - else: - text = '' - return ExtractedToolCallInformation(tools_called=True, - tool_calls=tool_calls, - content=text if len(text) > 0 else None) - - return ExtractedToolCallInformation(tools_called=False, tool_calls=[], content=text) diff --git
a/lmdeploy/serve/openai/tool_parser/qwen3_parser.py b/lmdeploy/serve/openai/tool_parser/qwen3_parser.py deleted file mode 100644 index 4b04410461..0000000000 --- a/lmdeploy/serve/openai/tool_parser/qwen3_parser.py +++ /dev/null @@ -1,187 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import json -import re -from collections.abc import Sequence -from dataclasses import dataclass - -import shortuuid - -from lmdeploy.serve.openai.protocol import ( - ChatCompletionRequest, - DeltaFunctionCall, - DeltaMessage, - DeltaToolCall, - ExtractedToolCallInformation, - FunctionCall, - ToolCall, -) -from lmdeploy.utils import get_logger - -from .tool_parser import ToolParser, ToolParserManager - -logger = get_logger('lmdeploy') - - -@dataclass -class ParserState: - """Maintains the state of parsing during tool call extraction.""" - position: int = 0 # Current position in the text being parsed - current_index: int = -1 # Index of the current tool call - parsing_reasoning: bool = False # Whether currently parsing reasoning content - - id: str = '' # ID of the current tool call - - def reset_tool_call(self): - """Called when the ``</tool_call>`` finish tag occurs.""" - self.id = '' - - -@ToolParserManager.register_module(['qwen', 'qwen3']) -class Qwen3ToolParser(ToolParser): - """Parser for Qwen3 model's tool call format. - - Handles the extraction of tool calls from Qwen3's output format, which uses XML-like tags for tool calls and - reasoning. - """ - - def __init__(self, tokenizer: object): - super().__init__(tokenizer) - self.tool_start_token = '<tool_call>' - self.tool_end_token = '</tool_call>' - self.tool_call_pat = re.compile(r'<tool_call>\n*(.*?)</tool_call>', re.DOTALL) - - def get_argments(self, obj): - """Extract arguments from tool call object, handling different formats. - - Supports both 'parameters' and 'arguments' keys in the tool call object.
- """ - if 'parameters' in obj: - return obj.get('parameters') - elif 'arguments' in obj: - return obj.get('arguments') - return None - - def _split(self, parser_state: ParserState, parsing_content: str): - """Split content into tuple: (text_content, tool_content, has_tool_end) - - This method parses the model output and separates it into regular text, - and tool call content. - """ - # tool call - try: - start_idx = parsing_content.index(self.tool_start_token) - # move to the beginning of tool_start_token - parser_state.position += start_idx - except ValueError: - parser_state.position += len(parsing_content) - return parsing_content, '', False - try: - end_idx = parsing_content.index(self.tool_end_token) - except ValueError: - # position holds until tool_end_token is found - return parsing_content[:start_idx], '', False - # move position to the end of tool_end_token - parser_state.position += (end_idx - start_idx) + len(self.tool_end_token) - return parsing_content[:start_idx], parsing_content[start_idx + len(self.tool_start_token):end_idx], True - - def _parse_delta_tool_call(self, parser_state: ParserState, tool_content: str) -> DeltaToolCall | None: - """Parse tool content into a DeltaToolCall object. 
- - This method handles parsing tool calls only when it's a valid tool - """ - parsable_arr = tool_content.strip() - try: - tool_call_arr: dict = json.loads(parsable_arr) - except json.JSONDecodeError: - logger.debug('cannot parse into JSON yet') - return - - fcall = DeltaFunctionCall() - func_name = tool_call_arr.get('name') - if func_name: - fcall.name = func_name - args = self.get_argments(tool_call_arr) - if args and isinstance(args, dict): - fcall.arguments = json.dumps(args, ensure_ascii=False) - # Return None if no new information to send - if not fcall.name and not fcall.arguments: - return - if not parser_state.id: - # A new tool call parsed, allocate a new id & index - parser_state.id = f'chatcmpl-tool-{shortuuid.random()}' - parser_state.current_index += 1 - # Create and return the DeltaToolCall object - return DeltaToolCall( - id=parser_state.id, - index=parser_state.current_index, - function=fcall.model_dump(exclude_none=True), - ) - - def extract_tool_calls_streaming( - self, - previous_text: str, - current_text: str, - delta_text: str, - previous_token_ids: Sequence[int], - current_token_ids: Sequence[int], - delta_token_ids: Sequence[int], - request: ChatCompletionRequest, - ) -> DeltaMessage | None: - """Extract tool calls from streaming model output. - - This method processes incremental model output to extract tool calls, reasoning content, and regular text - content in a streaming fashion. It maintains parser state between calls to handle partial outputs. 
- """ - parser_state = getattr(request, '_tool_parser_state', None) - if parser_state is None: - parser_state = ParserState() - setattr(request, '_tool_parser_state', parser_state) - - # Split the new content into text and tool content - split_result = self._split(parser_state, current_text[parser_state.position:]) - text_content, tool_content, has_tool_end = split_result - delta = DeltaMessage() - - # Add each type of content to the delta message if present - if text_content: - delta.content = text_content - if tool_content: - # Parse tool content into a DeltaToolCall object - delta_tool_call = self._parse_delta_tool_call(parser_state, tool_content) - if delta_tool_call is not None: - delta.tool_calls = [delta_tool_call] - if has_tool_end: - parser_state.reset_tool_call() - return delta - - def extract_tool_calls( - self, - model_output: str, - request: ChatCompletionRequest, - ) -> ExtractedToolCallInformation: - """Extract tool calls from complete model output. - - This method processes the full model output to extract tool calls, reasoning content, and regular text content. - Unlike the streaming version, this processes the entire output at once. 
- """ - text = model_output - - # Extract tool calls (content inside tags) - buf = [] - scan_pos = 0 - tool_calls = [] - for idx, match in enumerate(self.tool_call_pat.finditer(text)): - buf.append(text[scan_pos:match.start()]) # Add text before the tag - scan_pos = match.end() - action = json.loads(match.group(1)) # Parse the tool call JSON - name, arguments = action['name'], json.dumps(action['arguments'], ensure_ascii=False) - tool_calls.append(ToolCall(function=FunctionCall(name=name, arguments=arguments))) - if scan_pos < len(text): - buf.append(text[scan_pos:]) # Add remaining text - text = ''.join(buf) # Reconstruct text without tags - - return ExtractedToolCallInformation( - content=text, - tool_calls=tool_calls, - tools_called=bool(tool_calls), - ) diff --git a/lmdeploy/serve/openai/tool_parser/qwen3coder_parser.py b/lmdeploy/serve/openai/tool_parser/qwen3coder_parser.py deleted file mode 100644 index 7b22716af0..0000000000 --- a/lmdeploy/serve/openai/tool_parser/qwen3coder_parser.py +++ /dev/null @@ -1,307 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. 
-import json -import re -from collections.abc import Sequence -from dataclasses import dataclass -from typing import Any - -import shortuuid - -from lmdeploy.serve.openai.protocol import ( - ChatCompletionRequest, - DeltaFunctionCall, - DeltaMessage, - DeltaToolCall, - ExtractedToolCallInformation, - FunctionCall, - ToolCall, -) -from lmdeploy.utils import get_logger - -from .tool_parser import ToolParser, ToolParserManager - -logger = get_logger('lmdeploy') - - -def _parse_tool_call_arguments_dict(arguments: Any) -> dict[str, Any] | None: - """Return dict-like tool arguments for Qwen3Coder request rendering.""" - if not isinstance(arguments, str): - return None - - try: - parsed_arguments = json.loads(arguments) - except (json.JSONDecodeError, TypeError): - return None - if isinstance(parsed_arguments, dict): - return parsed_arguments - return None - - -@dataclass -class ParserState: - """Maintains the state of parsing during tool call extraction.""" - position: int = 0 # Current position in the text being parsed - current_index: int = -1 # Index of the current tool call - - id: str = '' # ID of the current tool call - - def reset_tool_call(self): - """Called when the ``</tool_call>`` finish tag occurs.""" - self.id = '' - - -@ToolParserManager.register_module(['qwen3coder']) -class Qwen3CoderToolParser(ToolParser): - """Parser for Qwen3 Coder model's tool call format.
- - Handles the extraction of tool calls from Qwen3 Coder's output format, which uses purely XML tags for function names - and parameters, e.g., ``<function=func_name><parameter=param_name>arg_value</parameter></function>`` - - """ - - def __init__(self, tokenizer: object): - super().__init__(tokenizer) - self.tool_start_token = '<tool_call>' - self.tool_end_token = '</tool_call>' - self.func_prefix = '<function=' - self.func_end_token = '</function>' - self.param_prefix = '<parameter=' - self.param_end_token = '</parameter>' - self.tool_call_pat = re.compile(r'<tool_call>(.*?)</tool_call>', re.DOTALL) - - def _normalize_request_messages(self, messages: list[dict]) -> list[dict] | None: - """Return a render-safe copy of request messages when needed.""" - normalized_messages = None - - for msg_idx, message in enumerate(messages): - if not isinstance(message, dict) or message.get('role') != 'assistant': - continue - tool_calls = message.get('tool_calls') - if not isinstance(tool_calls, list): - continue - - normalized_tool_calls = None - for tool_idx, tool_call in enumerate(tool_calls): - if not isinstance(tool_call, dict): - continue - function = tool_call.get('function') - if not isinstance(function, dict) or isinstance(function.get('arguments'), dict): - continue - - parsed_arguments = _parse_tool_call_arguments_dict(function.get('arguments')) - if parsed_arguments is None: - continue - - if normalized_messages is None: - normalized_messages = list(messages) - if normalized_tool_calls is None: - normalized_tool_calls = list(tool_calls) - normalized_message = dict(message) - normalized_message['tool_calls'] = normalized_tool_calls - normalized_messages[msg_idx] = normalized_message - - normalized_function = dict(function) - normalized_function['arguments'] = parsed_arguments - - normalized_tool_call = dict(tool_call) - normalized_tool_call['function'] = normalized_function - normalized_tool_calls[tool_idx] = normalized_tool_call - - return normalized_messages - - def adjust_request(self, request: ChatCompletionRequest) -> ChatCompletionRequest: - messages = request.messages - if not isinstance(messages, list): - return request - - normalized_messages = self._normalize_request_messages(messages) - if normalized_messages is None: - return
request - return request.model_copy(update={'messages': normalized_messages}) - - def _split(self, parser_state: ParserState, parsing_content: str) -> tuple[str, str, bool]: - """Split content into tuple: (text_content, tool_content, has_tool_end)""" - try: - start_idx = parsing_content.index(self.tool_start_token) - parser_state.position += start_idx - except ValueError: - parser_state.position += len(parsing_content) - return parsing_content, '', False - - try: - end_idx = parsing_content.index(self.tool_end_token) - except ValueError: - return parsing_content[:start_idx], parsing_content[start_idx:], False - - rem = end_idx - start_idx - parser_state.position += rem + len(self.tool_end_token) - return parsing_content[:start_idx], parsing_content[start_idx:end_idx + len(self.tool_end_token)], True - - def _extract_params(self, content: str) -> tuple[str | None, dict[str, Any], bool]: - """Parse XML tool content into components.""" - content = content.replace(self.tool_start_token, '').replace(self.tool_end_token, '').strip() - - func_name = None - func_start = content.find(self.func_prefix) - if func_start != -1: - name_start = func_start + len(self.func_prefix) - terminators = [idx for idx in (content.find('>', name_start), content.find('\n', name_start)) if idx != -1] - if terminators: - func_name = content[name_start:min(terminators)].strip() - - args_dict = {} - search_idx = 0 - while True: - param_start = content.find(self.param_prefix, search_idx) - if param_start == -1: - break - - name_start = param_start + len(self.param_prefix) - terminators = [idx for idx in (content.find('>', name_start), content.find('\n', name_start)) if idx != -1] - if not terminators: - break - - name_end = min(terminators) - param_name = content[name_start:name_end].strip() - - val_start = name_end + 1 - val_end = content.find(self.param_end_token, val_start) - if val_end == -1: - break - - param_val_str = content[val_start:val_end].strip() - - if param_val_str.lower() == 'null': 
- val = None - elif param_val_str.lower() == 'true': - val = True - elif param_val_str.lower() == 'false': - val = False - else: - try: - val = json.loads(param_val_str) - except json.JSONDecodeError: - val = param_val_str - args_dict[param_name] = val - search_idx = val_end + len(self.param_end_token) - - is_func_closed = self.func_end_token in content - return func_name, args_dict, is_func_closed - - def extract_tool_calls_streaming( - self, - previous_text: str, - current_text: str, - delta_text: str, - previous_token_ids: Sequence[int], - current_token_ids: Sequence[int], - delta_token_ids: Sequence[int], - request: ChatCompletionRequest, - ) -> DeltaMessage | None: - - parser_state = getattr(request, '_tool_parser_state', None) - if parser_state is None: - parser_state = ParserState() - setattr(request, '_tool_parser_state', parser_state) - - split_result = self._split(parser_state, current_text[parser_state.position:]) - text_content, tool_content, has_tool_end = split_result - - delta = DeltaMessage() - if text_content: - delta.content = text_content - - if tool_content: - if not parser_state.id: - parser_state.id = f'chatcmpl-tool-{shortuuid.random()}' - parser_state.current_index += 1 - parser_state.has_emitted_name = False - parser_state.has_emitted_json_start = False - parser_state.json_closed = False - parser_state.emitted_params = set() - - func_name, args_dict, is_func_closed = self._extract_params(tool_content) - - fcall_delta = DeltaFunctionCall() - has_updates = False - - if func_name and not getattr(parser_state, 'has_emitted_name', False): - fcall_delta.name = func_name - parser_state.has_emitted_name = True - has_updates = True - - json_fragments = [] - if not getattr(parser_state, 'has_emitted_json_start', False): - if args_dict or is_func_closed: - json_fragments.append('{') - parser_state.has_emitted_json_start = True - - for k, v in args_dict.items(): - if k not in parser_state.emitted_params: - prefix = ', ' if 
len(parser_state.emitted_params) > 0 else '' - serialized = json.dumps(v, ensure_ascii=False) - json_fragments.append(f'{prefix}"{k}": {serialized}') - parser_state.emitted_params.add(k) - - if is_func_closed and not getattr(parser_state, 'json_closed', False): - if getattr(parser_state, 'has_emitted_json_start', False): - json_fragments.append('}') - parser_state.json_closed = True - - joined_fragments = ''.join(json_fragments) - if joined_fragments: - fcall_delta.arguments = joined_fragments - has_updates = True - - if has_updates: - parsed_delta = DeltaToolCall( - id=parser_state.id, - index=parser_state.current_index, - function=fcall_delta, - ) - delta.tool_calls = [parsed_delta] - - if has_tool_end: - parser_state.reset_tool_call() - # Prepare for the next tool call - if hasattr(parser_state, 'has_emitted_name'): - delattr(parser_state, 'has_emitted_name') - delattr(parser_state, 'has_emitted_json_start') - delattr(parser_state, 'json_closed') - delattr(parser_state, 'emitted_params') - - return delta - - def extract_tool_calls( - self, - model_output: str, - request: ChatCompletionRequest, - ) -> ExtractedToolCallInformation: - text = model_output - buf = [] - scan_pos = 0 - tool_calls = [] - - for idx, match in enumerate(self.tool_call_pat.finditer(text)): - buf.append(text[scan_pos:match.start()]) - scan_pos = match.end() - - tool_content = match.group(1) - func_name, args_dict, _ = self._extract_params(tool_content) - - if func_name: - tool_calls.append( - ToolCall(function=FunctionCall( - name=func_name, arguments=json.dumps(args_dict, ensure_ascii=False) if args_dict else '{}'))) - - if scan_pos < len(text): - buf.append(text[scan_pos:]) - - text = ''.join(buf) - - return ExtractedToolCallInformation( - content=text, - tool_calls=tool_calls, - tools_called=bool(tool_calls), - ) diff --git a/lmdeploy/serve/openai/tool_parser/tool_parser.py b/lmdeploy/serve/openai/tool_parser/tool_parser.py deleted file mode 100644 index f919d33ef7..0000000000 --- 
a/lmdeploy/serve/openai/tool_parser/tool_parser.py +++ /dev/null @@ -1,67 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -# modified from https://github.com/vllm-project/vllm/tree/v0.7.3/vllm/entrypoints/openai/tool_parsers -from collections.abc import Sequence -from functools import cached_property - -from mmengine import Registry - -from lmdeploy.serve.openai.protocol import ChatCompletionRequest, DeltaMessage, ExtractedToolCallInformation -from lmdeploy.utils import get_logger - -logger = get_logger('lmdeploy') -ToolParserManager = Registry('tool_parser', locations=['lmdeploy.serve.openai.tool_parser']) - - -class ToolParser: - """Abstract ToolParser class that should not be used directly. - - Provided properties and methods should be used in derived classes. - """ - - def __init__(self, tokenizer: object): - self.prev_tool_call_arr: list[dict] = [] - # the index of the tool call that is currently being parsed - self.current_tool_id: int = -1 - self.current_tool_name_sent: bool = False - self.streamed_args_for_tool: list[str] = [] - - self.model_tokenizer = tokenizer - - @cached_property - def vocab(self) -> dict[str, int]: - # NOTE: Only PreTrainedTokenizerFast is guaranteed to have .vocab - # whereas all tokenizers have .get_vocab() - return self.model_tokenizer.get_vocab() - - def adjust_request(self, request: ChatCompletionRequest) -> ChatCompletionRequest: - """Static method that used to adjust the request parameters.""" - return request - - def extract_tool_calls(self, model_output: str, request: ChatCompletionRequest) -> ExtractedToolCallInformation: - """Static method that should be implemented for extracting tool calls - from a complete model-generated string. - - Used for non-streaming responses where we have the entire model response available before sending to the client. - Static because it's stateless. 
- """ - raise NotImplementedError('AbstractToolParser.extract_tool_calls has not been implemented!') - - def extract_tool_calls_streaming( - self, - previous_text: str, - current_text: str, - delta_text: str, - previous_token_ids: Sequence[int], - current_token_ids: Sequence[int], - delta_token_ids: Sequence[int], - request: ChatCompletionRequest, - ) -> DeltaMessage | None: - """Instance method that should be implemented for extracting tool calls - from an incomplete response; for use when handling tool calls and - streaming. - - Has to be an instance method because it requires state - the current tokens/diffs, but also the information - about what has previously been parsed and extracted (see constructor) - """ - raise NotImplementedError('AbstractToolParser.extract_tool_calls_streaming has not been ' - 'implemented!') diff --git a/lmdeploy/serve/openai/tool_parser/utils.py b/lmdeploy/serve/openai/tool_parser/utils.py deleted file mode 100644 index bee4728d8c..0000000000 --- a/lmdeploy/serve/openai/tool_parser/utils.py +++ /dev/null @@ -1,118 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -# Copied from https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/entrypoints/openai/tool_parsers/utils.py - -import json -from json import JSONDecodeError, JSONDecoder -from typing import Any - -import partial_json_parser -from partial_json_parser.core.options import Allow - - -def find_common_prefix(s1: str, s2: str) -> str: - """Finds a common prefix that is shared between two strings, if there is - one. Order of arguments is NOT important. - - This function is provided as a UTILITY for extracting information from JSON generated by partial_json_parser, to - help in ensuring that the right tokens are returned in streaming, so that close-quotes, close-brackets and close- - braces are not returned prematurely. - - e.g. 
find_common_prefix('{"fruit": "ap"}', '{"fruit": "apple"}') -> '{"fruit": "ap' - """ - prefix = '' - min_length = min(len(s1), len(s2)) - for i in range(0, min_length): - if s1[i] == s2[i]: - prefix += s1[i] - else: - break - return prefix - - -def find_common_suffix(s1: str, s2: str) -> str: - """Finds a common suffix shared between two strings, if there is one. Order - of arguments is NOT important. Stops when the suffix ends OR it hits an - alphanumeric character. - - e.g. find_common_suffix('{"fruit": "ap"}', '{"fruit": "apple"}') -> '"}' - """ - suffix = '' - min_length = min(len(s1), len(s2)) - for i in range(1, min_length + 1): - if s1[-i] == s2[-i] and not s1[-i].isalnum(): - suffix = s1[-i] + suffix - else: - break - return suffix - - -def extract_intermediate_diff(curr: str, old: str) -> str: - """Given two strings, extract the difference in the middle between two - strings that are known to have a common prefix and/or suffix. - - This function is provided as a UTILITY for extracting information from JSON - generated by partial_json_parser, to help in ensuring that the right tokens - are returned in streaming, so that close-quotes, close-brackets and - close-braces are not returned prematurely. The order of arguments IS - important - the new version of the partially-parsed JSON must be the first - argument, and the second argument must be from the previous generation. - - What it returns is the tokens that should be streamed to the client. - - e.g.
extract_intermediate_diff('{"fruit": "apple"}', '{"fruit": "ap"}') - -> 'ple' - """ - suffix = find_common_suffix(curr, old) - - old = old[::-1].replace(suffix[::-1], '', 1)[::-1] - prefix = find_common_prefix(curr, old) - diff = curr - if len(suffix): - diff = diff[::-1].replace(suffix[::-1], '', 1)[::-1] - - if len(prefix): - # replace the prefix only once in case it's mirrored - diff = diff.replace(prefix, '', 1) - - return diff - - -def find_all_indices(string: str, substring: str) -> list[int]: - """Find all (starting) indices of a substring in a given string. - - Useful for tool call extraction - """ - indices = [] - index = -1 - while True: - index = string.find(substring, index + 1) - if index == -1: - break - indices.append(index) - return indices - - -# partial_json_parser doesn't support extra data and -# JSONDecoder.raw_decode doesn't support partial JSON -def partial_json_loads(input_str: str, flags: Allow) -> tuple[Any, int]: - try: - return (partial_json_parser.loads(input_str, flags), len(input_str)) - except JSONDecodeError as e: - if 'Extra data' in e.msg: - dec = JSONDecoder() - return dec.raw_decode(input_str) - raise - - -def is_complete_json(input_str: str) -> bool: - try: - json.loads(input_str) - return True - except JSONDecodeError: - return False - - -def consume_space(i: int, s: str) -> int: - while i < len(s) and s[i].isspace(): - i += 1 - return i diff --git a/lmdeploy/serve/parsers/__init__.py b/lmdeploy/serve/parsers/__init__.py new file mode 100644 index 0000000000..7a0a26fcfb --- /dev/null +++ b/lmdeploy/serve/parsers/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) OpenMMLab. All rights reserved.
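The streaming tool parsers lean on the diff helpers in `utils.py` above to emit only the newly generated middle of a partially parsed JSON argument string. A self-contained sketch reproducing that behavior (reimplemented here for illustration, mirroring the deleted helpers):

```python
# find_common_prefix keeps close-quotes/brackets/braces from being streamed
# prematurely; extract_intermediate_diff strips the shared prefix and the
# shared non-alphanumeric suffix so only fresh characters reach the client.

def find_common_prefix(s1: str, s2: str) -> str:
    prefix = ''
    for a, b in zip(s1, s2):
        if a != b:
            break
        prefix += a
    return prefix


def find_common_suffix(s1: str, s2: str) -> str:
    # Only collects non-alphanumeric characters (quotes, braces, brackets).
    suffix = ''
    for i in range(1, min(len(s1), len(s2)) + 1):
        if s1[-i] == s2[-i] and not s1[-i].isalnum():
            suffix = s1[-i] + suffix
        else:
            break
    return suffix


def extract_intermediate_diff(curr: str, old: str) -> str:
    # curr is the newer partial JSON, old is from the previous generation.
    suffix = find_common_suffix(curr, old)
    old = old[::-1].replace(suffix[::-1], '', 1)[::-1]  # drop suffix from old
    prefix = find_common_prefix(curr, old)
    diff = curr
    if suffix:
        diff = diff[::-1].replace(suffix[::-1], '', 1)[::-1]
    if prefix:
        diff = diff.replace(prefix, '', 1)  # replace only once, in case mirrored
    return diff


# Two snapshots of partially generated JSON arguments:
print(find_common_prefix('{"fruit": "ap"}', '{"fruit": "apple"}'))        # {"fruit": "ap
print(extract_intermediate_diff('{"fruit": "apple"}', '{"fruit": "ap"}'))  # ple
```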
+# registers ResponseParser 'gpt-oss', None if openai_harmony unavailable +from .gpt_oss_response_parser import GptOssResponseParser +from .response_parser import ResponseParser, ResponseParserManager + +__all__ = ['ResponseParser', 'ResponseParserManager', 'GptOssResponseParser'] diff --git a/lmdeploy/serve/parsers/_openai_harmony.py b/lmdeploy/serve/parsers/_openai_harmony.py new file mode 100644 index 0000000000..17d04f41c8 --- /dev/null +++ b/lmdeploy/serve/parsers/_openai_harmony.py @@ -0,0 +1,200 @@ +# Copyright (c) OpenMMLab. All rights reserved. +"""GPT-OSS Harmony response parser; only imported when openai_harmony is +available.""" +from __future__ import annotations + +import re +from typing import TYPE_CHECKING + +import shortuuid +from openai_harmony import HarmonyEncodingName, Role, StreamableParser, load_harmony_encoding + +from lmdeploy.serve.openai.protocol import ( + DeltaFunctionCall, + DeltaMessage, + DeltaToolCall, + FunctionCall, + ToolCall, +) + +from .response_parser import ResponseParser, ResponseParserManager + +if TYPE_CHECKING: + from transformers import PreTrainedTokenizerBase + + from lmdeploy.serve.openai.protocol import ChatCompletionRequest + +_harmony_encoding = None + + +def get_encoding(): + global _harmony_encoding + if _harmony_encoding is None: + _harmony_encoding = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS) + return _harmony_encoding + + +@ResponseParserManager.register_module('gpt-oss') +class GptOssResponseParser(ResponseParser): + """Harmony stream parser for GPT-OSS (assistant role).""" + + def __init__(self, request: ChatCompletionRequest, tokenizer: PreTrainedTokenizerBase): + self.request = request + self.model_tokenizer = tokenizer + self.parser = StreamableParser(get_encoding(), role=Role.ASSISTANT) + self._seen_any = False + self._next_tool_index = 0 + self._active_tool_id: str | None = None + self._active_tool_index: int | None = None + self._active_tool_name: str | None = None + self.tool_parser = 
object()  # API server checks `is not None` for tool support.
+
+    def stream_chunk(self, delta_text: str, delta_token_ids: list[int], **kwargs) -> tuple[DeltaMessage | None, bool]:
+        if (
+            not delta_text
+            and not delta_token_ids
+            and not self._seen_any
+        ):
+            return DeltaMessage(role='assistant', content=''), False
+
+        self._seen_any = True
+
+        # Harmony parsing is token-based. If a backend emits text without ids,
+        # degrade gracefully as plain content.
+        if not delta_token_ids:
+            if not delta_text:
+                return None, False
+            return DeltaMessage(role='assistant', content=delta_text), False
+
+        content = ''
+        reasoning = ''
+        tool_deltas: list[DeltaToolCall] = []
+
+        for token in delta_token_ids:
+            prev_recipient = self.parser.current_recipient
+            self.parser.process(token)
+            cur_channel = self.parser.current_channel
+            cur_recipient = self.parser.current_recipient
+            token_delta = self.parser.last_content_delta or ''
+
+            tool_name = self._extract_tool_name(cur_recipient)
+            prev_tool_name = self._extract_tool_name(prev_recipient)
+            is_tool_channel = cur_channel in ('commentary', 'analysis')
+
+            if is_tool_channel and tool_name:
+                # Start of a new tool call.
+                if tool_name != prev_tool_name:
+                    self._active_tool_id = f'chatcmpl-tool-{shortuuid.random()}'
+                    self._active_tool_index = self._next_tool_index
+                    self._active_tool_name = tool_name
+                    self._next_tool_index += 1
+                    tool_deltas.append(
+                        DeltaToolCall(
+                            id=self._active_tool_id,
+                            index=self._active_tool_index,
+                            type='function',
+                            function=DeltaFunctionCall(name=tool_name),
+                        ))
+
+                if token_delta and self._active_tool_id is not None and self._active_tool_index is not None:
+                    tool_deltas.append(
+                        DeltaToolCall(
+                            id=self._active_tool_id,
+                            index=self._active_tool_index,
+                            type=None,
+                            function=DeltaFunctionCall(arguments=token_delta),
+                        ))
+                continue
+
+            # Normal textual channels.
+            if cur_channel == 'final':
+                content += token_delta
+            elif cur_channel == 'analysis':
+                reasoning += token_delta
+
+        if not content and not reasoning and not tool_deltas:
+            return None, False
+
+        return DeltaMessage(
+            role='assistant',
+            content=content or None,
+            reasoning_content=reasoning or None,
+            tool_calls=tool_deltas or None,
+        ), bool(tool_deltas)
+
+    def parse_complete(self, text: str, **kwargs) -> tuple[str, list | None, str | None]:
+        token_ids = kwargs.get('token_ids') or []
+        if not token_ids:
+            # Non-streaming path may not always pass token ids yet.
+            return text if text else None, None, None
+
+        content = ''
+        reasoning = ''
+
+        calls: list[dict] = []
+        active: dict | None = None
+
+        for token in token_ids:
+            prev_recipient = self.parser.current_recipient
+            self.parser.process(token)
+            cur_channel = self.parser.current_channel
+            cur_recipient = self.parser.current_recipient
+            token_delta = self.parser.last_content_delta or ''
+
+            tool_name = self._extract_tool_name(cur_recipient)
+            prev_tool_name = self._extract_tool_name(prev_recipient)
+            is_tool_channel = cur_channel in ('commentary', 'analysis')
+
+            if is_tool_channel and tool_name:
+                if tool_name != prev_tool_name:
+                    if active is not None:
+                        calls.append(active)
+                    active = {
+                        'id': f'chatcmpl-tool-{shortuuid.random()}',
+                        'name': tool_name,
+                        'arguments': '',
+                    }
+                if token_delta and active is not None:
+                    active['arguments'] += token_delta
+                continue
+
+            if active is not None:
+                calls.append(active)
+                active = None
+
+            if cur_channel == 'final':
+                content += token_delta
+            elif cur_channel == 'analysis':
+                reasoning += token_delta
+
+        if active is not None:
+            calls.append(active)
+
+        tool_calls = [
+            ToolCall(
+                id=call['id'],
+                type='function',
+                function=FunctionCall(name=call['name'], arguments=call['arguments']),
+            ) for call in calls
+        ] or None
+
+        return content or None, tool_calls, reasoning or None
+
+    @staticmethod
+    def _extract_tool_name(recipient: str | None) -> str | None:
+        """Extract function name from recipient string.
+
+        Handles malformed sequences like
+        ``functions.bash<|channel|>commentary`` by stripping harmony tags.
+        """
+        if not recipient:
+            return None
+        idx = recipient.find('functions.')
+        if idx < 0:
+            return None
+        clean = recipient[idx:]
+        clean = clean.split('<|channel|>', 1)[0]
+        clean = re.split(r'[\s<|]', clean, maxsplit=1)[0]
+        if not clean.startswith('functions.') or len(clean) <= len('functions.'):
+            return None
+        return clean.split('functions.', 1)[1]
diff --git a/lmdeploy/serve/parsers/gpt_oss_response_parser.py b/lmdeploy/serve/parsers/gpt_oss_response_parser.py
new file mode 100644
index 0000000000..c67e50a2e1
--- /dev/null
+++ b/lmdeploy/serve/parsers/gpt_oss_response_parser.py
@@ -0,0 +1,24 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+"""GPT-OSS response parser entry; loads Harmony implementation only when
+openai_harmony is installed."""
+from __future__ import annotations
+
+from lmdeploy.utils import get_logger
+
+logger = get_logger('lmdeploy')
+
+_OPENAI_HARMONY_AVAILABLE = False
+try:
+    import openai_harmony  # noqa: F401
+except ImportError as e:
+    logger.warning(
+        'openai_harmony import failed (%s). Install openai_harmony for GPT-OSS Harmony response '
+        'parsing; without it the server uses the default response parser for GPT-OSS models.',
+        e,
+    )
+else:
+    _OPENAI_HARMONY_AVAILABLE = True
+
+GptOssResponseParser = None  # type: ignore[misc, assignment]
+if _OPENAI_HARMONY_AVAILABLE:
+    pass  # type: ignore[import-untyped]
diff --git a/lmdeploy/serve/parsers/harmony_utils.py b/lmdeploy/serve/parsers/harmony_utils.py
new file mode 100644
index 0000000000..ef101fec61
--- /dev/null
+++ b/lmdeploy/serve/parsers/harmony_utils.py
@@ -0,0 +1 @@
+# Copyright (c) OpenMMLab. All rights reserved.
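The recipient-cleanup logic in `_extract_tool_name` above can be sketched as a standalone function. This is an illustrative sketch, not the module's public API; the helper name is hypothetical, but the behavior mirrors the diff: find the `functions.` prefix, strip trailing Harmony tags, and return only the bare function name.

```python
import re


def extract_tool_name(recipient):
    """Pull the function name out of a Harmony recipient string such as
    'functions.bash', tolerating malformed sequences like
    'functions.bash<|channel|>commentary' by stripping trailing tags."""
    if not recipient:
        return None
    idx = recipient.find('functions.')
    if idx < 0:
        return None
    # Cut at an explicit channel tag first, then at any whitespace or
    # special-token opener that may follow the name.
    clean = recipient[idx:].split('<|channel|>', 1)[0]
    clean = re.split(r'[\s<|]', clean, maxsplit=1)[0]
    if not clean.startswith('functions.') or len(clean) <= len('functions.'):
        return None
    return clean.split('functions.', 1)[1]
```

Recipients outside the `functions.*` namespace (for example built-in tools) deliberately yield `None`, so only model-declared functions become tool calls.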
diff --git a/lmdeploy/serve/parsers/reasoning_parser/__init__.py b/lmdeploy/serve/parsers/reasoning_parser/__init__.py
new file mode 100644
index 0000000000..6a4079450c
--- /dev/null
+++ b/lmdeploy/serve/parsers/reasoning_parser/__init__.py
@@ -0,0 +1,9 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .deepseek_v3_reasoning_parser import DeepSeekV3ReasoningParser
+from .reasoning_parser import ReasoningParser, ReasoningParserManager
+
+__all__ = [
+    'ReasoningParser',
+    'ReasoningParserManager',
+    'DeepSeekV3ReasoningParser',
+]
diff --git a/lmdeploy/serve/parsers/reasoning_parser/deepseek_v3_reasoning_parser.py b/lmdeploy/serve/parsers/reasoning_parser/deepseek_v3_reasoning_parser.py
new file mode 100644
index 0000000000..93bb6e64c9
--- /dev/null
+++ b/lmdeploy/serve/parsers/reasoning_parser/deepseek_v3_reasoning_parser.py
@@ -0,0 +1,20 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .reasoning_parser import ReasoningParser, ReasoningParserManager
+
+
+@ReasoningParserManager.register_module('deepseek-v3')
+class DeepSeekV3ReasoningParser(ReasoningParser):
+    """Reasoning parser for DeepSeek-V3.
+
+    DeepSeek-V3 differs from the qwen3 default behavior:
+    - ``enable_thinking=True``: the model can emit a reasoning stream
+      (``<think>`` ... ``</think>``)
+    - ``enable_thinking=None``: the model typically emits no reasoning part
+    """
+
+    def __init__(self, tokenizer: object, **kwargs):
+        super().__init__(tokenizer, **kwargs)
+        self.enable_thinking = kwargs.get('enable_thinking', None)
+
+    def starts_in_reasoning_mode(self) -> bool:
+        # Enter reasoning mode only when explicitly requested.
+        return self.enable_thinking is True
diff --git a/lmdeploy/serve/parsers/reasoning_parser/reasoning_parser.py b/lmdeploy/serve/parsers/reasoning_parser/reasoning_parser.py
new file mode 100644
index 0000000000..c7efe0fc10
--- /dev/null
+++ b/lmdeploy/serve/parsers/reasoning_parser/reasoning_parser.py
@@ -0,0 +1,34 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+from mmengine import Registry
+
+if TYPE_CHECKING:
+    from transformers import PreTrainedTokenizerBase
+
+ReasoningParserManager = Registry('reasoning_parser', locations=['lmdeploy.serve.parsers.reasoning_parser'])
+
+
+@ReasoningParserManager.register_module(name='default')
+class ReasoningParser:
+    """Unified reasoning parser for all ``--reasoning-parser`` options."""
+
+    def __init__(self, tokenizer: PreTrainedTokenizerBase, **kwargs):
+        vocab = tokenizer.get_vocab()
+        start_token_id = vocab.get(self.get_reasoning_open_tag())
+        end_token_id = vocab.get(self.get_reasoning_close_tag())
+        if start_token_id is None or end_token_id is None:
+            raise RuntimeError(f'{self.__class__.__name__} reasoning parser could not get '
+                               'reasoning tokens from the tokenizer!')
+
+    def get_reasoning_open_tag(self) -> str | None:
+        return '<think>'
+
+    def get_reasoning_close_tag(self) -> str | None:
+        return '</think>'
+
+    def starts_in_reasoning_mode(self) -> bool:
+        return True
diff --git a/lmdeploy/serve/parsers/response_parser.py b/lmdeploy/serve/parsers/response_parser.py
new file mode 100644
index 0000000000..82daf6f595
--- /dev/null
+++ b/lmdeploy/serve/parsers/response_parser.py
@@ -0,0 +1,501 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+"""Unified profile-driven streaming parser for reasoning/content/tool calls."""
+from __future__ import annotations
+
+from abc import abstractmethod
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, ClassVar
+
+from mmengine import Registry
+
+from lmdeploy.serve.openai.protocol import DeltaMessage
+from lmdeploy.utils import get_logger
+
+if TYPE_CHECKING:
+    from transformers import PreTrainedTokenizerBase
+
+    from lmdeploy.serve.openai.protocol import ChatCompletionRequest, DeltaToolCall, ToolCall
+
+    from .reasoning_parser import ReasoningParser
+    from .tool_parser import ToolParser
+
+logger = get_logger('lmdeploy')
+
+ResponseParserManager = Registry('response_parser', locations=['lmdeploy.serve.parsers.response_parser'])
+
+
+class ResponseParser:
+
+    @classmethod
+    def set_parsers(cls, reasoning_parser_name: str | None = None, tool_parser_name: str | None = None) -> None:
+        pass
+
+    def __init__(self, request: ChatCompletionRequest, tokenizer: PreTrainedTokenizerBase):
+        pass
+
+    @abstractmethod
+    def stream_chunk(self, delta_text: str, delta_token_ids: list[int], **kwargs) -> tuple[DeltaMessage | None, bool]:
+        raise NotImplementedError
+
+    @abstractmethod
+    def parse_complete(self, text: str, **kwargs) -> tuple[str, list | None, str | None]:
+        raise NotImplementedError
+
+
+@dataclass
+class ProtocolProfile:
+    """Protocol tags and startup mode used by :class:`ResponseParser`.
+
+    ``starts_in_reasoning_mode`` decides the initial parse mode before any tags are seen.
+    In ResponseParser, it controls whether the parser treats the beginning of generation as:
+    - reasoning (MODE_REASONING) -> text goes to reasoning_content, or
+    - plain (MODE_PLAIN) -> text goes to normal content.
+    Practically:
+    - If the parser has reasoning support, ``enable_thinking`` is not False, and
+      ``starts_in_reasoning_mode=True``, the first chunks are parsed as reasoning until ``</think>``.
+    - Otherwise it starts in plain mode and only enters reasoning when it sees ``<think>``.
+    It is only a profile default and can be customized by concrete reasoning
+    parsers (for example DeepSeek-V3).
+    """
+
+    reasoning_open_tag: str | None = None
+    reasoning_close_tag: str | None = None
+    tool_open_tag: str | None = None
+    tool_close_tag: str | None = None
+    tool_payload_format: str = 'json'
+    starts_in_reasoning_mode: bool = True
+
+
+@dataclass
+class _QueuedDelta:
+    delta: DeltaMessage
+    tool_calls_emitted: bool = False
+
+
+@ResponseParserManager.register_module('default')
+class BaseResponseParser(ResponseParser):
+    """The default response parser for streaming and complete assistant
+    responses.
+
+    It separates model output into:
+    - plain assistant content
+    - reasoning content
+    - tool-call deltas
+
+    Parsing is protocol/profile-driven and supports mixed chunks where one
+    ``delta_text`` may contain multiple segments (for example reasoning close
+    plus plain text plus tool open tag).
+    """
+
+    reasoning_parser_cls: ClassVar[type[ReasoningParser] | None] = None
+    tool_parser_cls: ClassVar[type[ToolParser] | None] = None
+    MODE_PLAIN: ClassVar[str] = 'plain'
+    MODE_REASONING: ClassVar[str] = 'reasoning'
+    MODE_TOOL: ClassVar[str] = 'tool'
+
+    @classmethod
+    def set_parsers(
+        cls,
+        reasoning_parser_name: str | None = None,
+        tool_parser_name: str | None = None,
+    ) -> None:
+        """Configure reasoning/tool parser classes by registry name."""
+        from .reasoning_parser import ReasoningParserManager
+        from .tool_parser import ToolParserManager
+
+        legacy_reasoning_parser_names = ['qwen-qwq', 'intern-s1', 'deepseek-r1']
+        if reasoning_parser_name in legacy_reasoning_parser_names:
+            logger.warning(f'The reasoning parser {reasoning_parser_name} is deprecated, '
+                           'please use the default reasoning parser instead.')
+            reasoning_parser_name = 'default'
+
+        if reasoning_parser_name is not None:
+            if reasoning_parser_name in ReasoningParserManager.module_dict:
+                cls.reasoning_parser_cls = ReasoningParserManager.get(reasoning_parser_name)
+            else:
+                raise ValueError(f'The reasoning parser {reasoning_parser_name} is not in the parser list: '
+                                 f'{ReasoningParserManager.module_dict.keys()}')
+
+        if tool_parser_name is not None:
+            if tool_parser_name in ToolParserManager.module_dict:
+                cls.tool_parser_cls = ToolParserManager.get(tool_parser_name)
+            else:
+                raise ValueError(f'The tool parser {tool_parser_name} is not in the parser list: '
+                                 f'{ToolParserManager.module_dict.keys()}')
+
+    @classmethod
+    def chat_template_kwargs_from_request(cls, request: ChatCompletionRequest) -> dict:
+        """Normalize parser-related template kwargs from the request.
+
+        ``enable_thinking`` is a deprecated top-level field. This helper maps
+        it into ``chat_template_kwargs`` so downstream parser behavior can rely
+        on one normalized source.
+        """
+        chat_template_kwargs = request.chat_template_kwargs or {}
+        if request.enable_thinking is not None:
+            logger.warning('`enable_thinking` will be deprecated in the future, '
+                           'please use `chat_template_kwargs` instead.')
+            if chat_template_kwargs.get('enable_thinking') is None:
+                chat_template_kwargs['enable_thinking'] = request.enable_thinking
+            else:
+                logger.warning('`enable_thinking` in `chat_template_kwargs` will override the value in request.')
+        return chat_template_kwargs
+
+    def __init__(
+        self,
+        request: ChatCompletionRequest,
+        tokenizer: PreTrainedTokenizerBase,
+    ):
+        rcls = type(self).reasoning_parser_cls
+        tcls = type(self).tool_parser_cls
+        self._kwargs = type(self).chat_template_kwargs_from_request(request)
+        self.enable_thinking: bool | None = self._kwargs.get('enable_thinking', None)
+        self.reasoning_parser: ReasoningParser | None = (
+            rcls(tokenizer, **self._kwargs) if rcls else None
+        )
+        self.tool_parser: ToolParser | None = (
+            tcls(tokenizer) if tcls else None
+        )
+        if self.tool_parser is not None:
+            self.request = self.tool_parser.adjust_request(request)
+        else:
+            self.request = request
+        self._accumulated_text = ''
+
+        self.profile = self._build_profile()
+        if (self.reasoning_parser is not None and self.enable_thinking is not False
+                and self.profile.starts_in_reasoning_mode):
+            self._mode = self.MODE_REASONING
+        else:
+            self._mode = self.MODE_PLAIN
+        self._pending = ''
+        self._queued_deltas: list[_QueuedDelta] = []
+
+    def stream_chunk(
+        self,
+        delta_text: str,
+        delta_token_ids: list[int],
+        **kwargs,
+    ) -> tuple[DeltaMessage | None, bool]:
+        """Parse one streamed chunk into delta message channels.
+
+        Args:
+            delta_text: New text fragment produced in this stream step.
+            delta_token_ids: Token ids corresponding to ``delta_text``.
+
+        Returns:
+            ``(delta_message, tool_calls_emitted)`` where:
+            - ``delta_message`` is ``None`` when this step has no visible delta.
+            - ``tool_calls_emitted`` is ``True`` if at least one tool-call
+              delta is emitted in this step.
+        """
+        # Special-case: some backends emit a leading empty delta (no text, no
+        # tokens) before any actual content. Tests treat this as a visible empty
+        # content delta.
+        if (
+            not delta_text
+            and not delta_token_ids
+            and self._accumulated_text == ''
+        ):
+            return DeltaMessage(role='assistant', content=''), False
+
+        if self.tool_parser is None and self.reasoning_parser is None:
+            return DeltaMessage(role='assistant', content=delta_text), False
+
+        self._accumulated_text += delta_text
+        self._pending += delta_text
+        produced_any = False
+
+        while True:
+            progressed = False
+            if self._mode == self.MODE_PLAIN:
+                emitted, progressed = self._consume_plain()
+                if emitted:
+                    self._queued_deltas.append(_QueuedDelta(DeltaMessage(role='assistant', content=emitted), False))
+                    produced_any = True
+            elif self._mode == self.MODE_REASONING:
+                emitted, progressed = self._consume_reasoning()
+                if emitted:
+                    if self.enable_thinking is False:
+                        self._queued_deltas.append(
+                            _QueuedDelta(DeltaMessage(role='assistant', content=emitted), False))
+                    else:
+                        self._queued_deltas.append(
+                            _QueuedDelta(DeltaMessage(role='assistant', reasoning_content=emitted), False))
+                    produced_any = True
+            if self._mode == self.MODE_TOOL:
+                # self._consume_plain() might change the mode to MODE_TOOL,
+                # so we need to check the mode again.
+                new_calls, progressed = self._consume_tool()
+                if new_calls:
+                    self._queued_deltas.append(
+                        _QueuedDelta(DeltaMessage(role='assistant', tool_calls=new_calls), True))
+                    produced_any = True
+            if not progressed:
+                break
+
+        # Special case: a trailing empty delta (delta_text == '') after non-empty
+        # output should be surfaced as an explicit empty content delta so that
+        # streaming clients see the final "no-op" chunk (some backends do this).
+        if (
+            delta_text == ''
+            and not produced_any
+            and self._accumulated_text != ''
+        ):
+            self._queued_deltas.append(_QueuedDelta(DeltaMessage(role='assistant', content=''), False))
+        if not self._queued_deltas:
+            return None, False
+        queued = self._queued_deltas.pop(0)
+        return queued.delta, queued.tool_calls_emitted
+
+    def _consume_plain(self) -> tuple[str | None, bool]:
+        """Consume buffered text while in plain mode.
+
+        Behavior:
+        - Finds the earliest protocol opening tag (reasoning/tool) in
+          ``self._pending``.
+        - If no full tag is present, emits only the safe plain-text prefix and
+          preserves a possible partial-tag suffix for the next chunk.
+        - If a tag is found, emits text before the tag as plain content,
+          consumes the tag, and switches mode:
+          - reasoning open tag -> ``MODE_REASONING``
+          - tool open tag -> ``MODE_TOOL`` (also initializes tool-call state)
+
+        Returns:
+            ``(emitted_text, progressed)`` where ``emitted_text`` is the plain
+            content produced in this step (or ``None``), and ``progressed``
+            indicates whether parser state/input was consumed.
+        """
+        tags = [t for t in (self.profile.reasoning_open_tag, self.profile.tool_open_tag) if t]
+        if not tags:
+            if not self._pending:
+                return None, False
+            out = self._pending
+            self._pending = ''
+            return out, True
+
+        # Find the earliest protocol open tag.
+        earliest_idx = -1
+        earliest_tag = None
+        for tag in tags:
+            idx = self._pending.find(tag)
+            if idx >= 0 and (earliest_idx < 0 or idx < earliest_idx):
+                earliest_idx = idx
+                earliest_tag = tag
+
+        # No protocol open tag found, treat the whole pending text as plain content.
+        if earliest_idx < 0:
+            if not self._pending:
+                return None, False
+            out = self._pending
+            self._pending = ''
+            return out, True
+
+        # Emit content before the protocol open tag.
+        prefix = self._pending[:earliest_idx]
+        self._pending = self._pending[earliest_idx + len(earliest_tag):]
+        if earliest_tag == self.profile.reasoning_open_tag:
+            self._mode = self.MODE_REASONING
+        else:
+            self._mode = self.MODE_TOOL
+            if self.tool_parser is not None:
+                self.tool_parser.start_tool_call()
+        return (prefix if prefix else None), True
+
+    def _consume_reasoning(self) -> tuple[str | None, bool]:
+        """Consume buffered text while in reasoning mode.
+
+        Behavior:
+        - Drops the explicit open tag if the model emits it.
+        - If no close tag is present, emits only the safe reasoning-text prefix and
+          preserves a possible partial-tag suffix for the next chunk.
+        - If a close tag is found, emits text before the close tag as reasoning content,
+          consumes the close tag, and switches mode to ``MODE_PLAIN``.
+
+        Returns:
+            ``(emitted_text, progressed)`` where ``emitted_text`` is the reasoning
+            content produced in this step (or ``None``), and ``progressed``
+            indicates whether parser state/input was consumed.
+        """
+        open_tag = self.profile.reasoning_open_tag
+        # Drop the explicit open tag if the model emits it.
+        if open_tag and self._pending.startswith(open_tag):
+            self._pending = self._pending[len(open_tag):]
+            return None, True
+
+        close_tag = self.profile.reasoning_close_tag
+        if not close_tag:
+            raise RuntimeError('Invariant violated: MODE_REASONING requires a reasoning_close_tag.')
+
+        idx = self._pending.find(close_tag)
+        # No close tag found, treat the whole pending text as reasoning content.
+        if idx < 0:
+            if not self._pending:
+                return None, False
+            out = self._pending
+            self._pending = ''
+            return out, True
+
+        reasoning_chunk = self._pending[:idx]
+        self._pending = self._pending[idx + len(close_tag):]
+        # The reasoning part is done; switch to plain mode.
+        self._mode = self.MODE_PLAIN
+        return (reasoning_chunk if reasoning_chunk else None), True
+
+    def _consume_tool(self) -> tuple[list[DeltaToolCall], bool]:
+        """Consume buffered text while in tool mode.
+
+        Behavior:
+        - Treats ``self._pending`` as tool payload bytes until ``tool_close_tag``
+          is found.
+        - For non-final payload chunks, forwards text to
+          ``tool_parser.decode_tool_incremental(..., final=False)``.
+        - For the final payload chunk (before the close tag), forwards text with
+          ``final=True``, then calls ``tool_parser.finish_tool_call()`` and
+          switches mode back to ``MODE_PLAIN``.
+        - This method is format-agnostic: JSON/XML/other details are handled
+          entirely by the concrete tool parser implementation.
+
+        Returns:
+            ``(tool_call_deltas, progressed)`` where ``tool_call_deltas`` is the
+            list emitted by the tool parser for this step (possibly empty), and
+            ``progressed`` indicates whether parser state/input was consumed.
+        """
+        if self.tool_parser is None:
+            raise RuntimeError('Invariant violated: MODE_TOOL requires a tool_parser.')
+
+        close_tag = self.profile.tool_close_tag
+        if not close_tag:
+            if not self._pending:
+                return [], False
+            emit = self._pending
+            self._pending = ''
+            return self.tool_parser.decode_tool_incremental(added_text=emit, final=False), True
+
+        idx = self._pending.find(close_tag)
+        if idx < 0:
+            if not self._pending:
+                return [], False
+            emit = self._pending
+            self._pending = ''
+            return self.tool_parser.decode_tool_incremental(added_text=emit, final=False), True
+
+        # Final chunk inside the tool block.
+        inner = self._pending[:idx]
+        self._pending = self._pending[idx + len(close_tag):]
+        calls = self.tool_parser.decode_tool_incremental(added_text=inner, final=True)
+        self.tool_parser.finish_tool_call()
+        self._mode = self.MODE_PLAIN
+        return calls, True
+
+    def _build_profile(self) -> ProtocolProfile:
+        profile = ProtocolProfile(starts_in_reasoning_mode=False)
+        rparser = self.reasoning_parser
+        tparser = self.tool_parser
+
+        if rparser is not None:
+            profile.reasoning_open_tag = rparser.get_reasoning_open_tag()
+            profile.reasoning_close_tag = rparser.get_reasoning_close_tag()
+            profile.starts_in_reasoning_mode = bool(rparser.starts_in_reasoning_mode())
+            if not profile.reasoning_close_tag:
+                raise RuntimeError(f'Reasoning parser {rparser.__class__.__name__} must provide a reasoning end tag')
+
+        if tparser is not None and self.request.tool_choice != 'none':
+            profile.tool_open_tag = tparser.get_tool_open_tag()
+            profile.tool_close_tag = tparser.get_tool_close_tag()
+            profile.tool_payload_format = tparser.get_tool_payload_format()
+            if not profile.tool_open_tag:
+                raise RuntimeError(f'Tool parser {tparser.__class__.__name__} must provide a tool start tag')
+            if not profile.tool_close_tag:
+                raise RuntimeError(f'Tool parser {tparser.__class__.__name__} must provide a tool end tag')
+        return profile
+
+    def parse_complete(
+        self,
+        text: str,
+        **kwargs,
+    ) -> tuple[str, list | None, str | None]:
+        """Parse the final non-streaming text output.
+
+        Args:
+            text: Full generated output text.
+
+        Returns:
+            A tuple ``(content, tool_calls, reasoning_content)``:
+            - ``content``: plain assistant-visible text, or ``None``
+            - ``tool_calls``: parsed tool calls, or ``None``
+            - ``reasoning_content``: separated reasoning text, or ``None``
+        """
+        content_parts: list[str] = []
+        reasoning_parts: list[str] = []
+        tool_calls: list[ToolCall] = []
+        pos = 0
+        mode = self.MODE_REASONING if (self.profile.starts_in_reasoning_mode and self.reasoning_parser is not None
+                                       and self.enable_thinking is not False) else self.MODE_PLAIN
+        n = len(text)
+
+        while pos < n:
+            if mode == self.MODE_REASONING:
+                close_tag = self.profile.reasoning_close_tag
+                close_idx = text.find(close_tag, pos) if close_tag else -1
+                if close_idx < 0:
+                    piece = text[pos:]
+                    if self.enable_thinking is False:
+                        content_parts.append(piece)
+                    else:
+                        reasoning_parts.append(piece)
+                    break
+                piece = text[pos:close_idx]
+                if piece:
+                    if self.enable_thinking is False:
+                        content_parts.append(piece)
+                    else:
+                        reasoning_parts.append(piece)
+                pos = close_idx + len(close_tag)
+                mode = self.MODE_PLAIN
+                continue
+
+            open_idx, open_tag = self._find_first(
+                text,
+                [t for t in (self.profile.reasoning_open_tag, self.profile.tool_open_tag) if t],
+                pos,
+            )
+            if open_idx < 0:
+                content_parts.append(text[pos:])
+                break
+
+            if open_idx > pos:
+                content_parts.append(text[pos:open_idx])
+
+            if open_tag == self.profile.reasoning_open_tag:
+                mode = self.MODE_REASONING
+                pos = open_idx + len(open_tag)
+                continue
+
+            # tool block
+            close_tag = self.profile.tool_close_tag
+            close_idx = text.find(close_tag, open_idx + len(open_tag)) if close_tag else -1
+            if close_idx < 0:
+                # Unterminated tool block: keep it as plain text.
+                content_parts.append(text[open_idx:])
+                break
+            tool_payload = text[open_idx + len(open_tag):close_idx].strip()
+            parsed_call = self.tool_parser.parse_tool_call_complete(tool_payload) if self.tool_parser else None
+            if parsed_call is not None:
+                tool_calls.append(parsed_call)
+            pos = close_idx + len(close_tag)
+
+        content = ''.join(content_parts)
+        reasoning_content = ''.join(reasoning_parts) if reasoning_parts else None
+        return content if content != '' else None, tool_calls or None, reasoning_content
+
+    @staticmethod
+    def _find_first(text: str, tags: list[str], start: int) -> tuple[int, str]:
+        best_idx = -1
+        best_tag = ''
+        for tag in tags:
+            idx = text.find(tag, start)
+            if idx >= 0 and (best_idx < 0 or idx < best_idx):
+                best_idx = idx
+                best_tag = tag
+        return best_idx, best_tag
diff --git a/lmdeploy/serve/openai/tool_parser/__init__.py b/lmdeploy/serve/parsers/tool_parser/__init__.py
similarity index 52%
rename from lmdeploy/serve/openai/tool_parser/__init__.py
rename to lmdeploy/serve/parsers/tool_parser/__init__.py
index e1e2b2726e..c138ce5dfb 100644
--- a/lmdeploy/serve/openai/tool_parser/__init__.py
+++ b/lmdeploy/serve/parsers/tool_parser/__init__.py
@@ -1,17 +1,17 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from .internlm2_parser import Internlm2ToolParser
-from .llama3_parser import Llama3JsonToolParser
-from .qwen2d5_parser import Qwen2d5ToolParser
-from .qwen3_parser import Qwen3ToolParser
-from .qwen3coder_parser import Qwen3CoderToolParser
+from .internlm2_tool_parser import Internlm2ToolParser
+from .llama3_tool_parser import Llama3JsonToolParser
+from .qwen2d5_tool_parser import Qwen2d5ToolParser
+from .qwen3_tool_parser import Qwen3ToolParser
+from .qwen3coder_tool_parser import Qwen3CoderToolParser
 from .tool_parser import ToolParser, ToolParserManager
 
 __all__ = [
+    'ToolParser',
+    'ToolParserManager',
     'Internlm2ToolParser',
+    'Llama3JsonToolParser',
     'Qwen2d5ToolParser',
     'Qwen3ToolParser',
     'Qwen3CoderToolParser',
-    'ToolParser',
-    'ToolParserManager',
-    'Llama3JsonToolParser',
 ]
diff --git a/lmdeploy/serve/parsers/tool_parser/internlm2_tool_parser.py b/lmdeploy/serve/parsers/tool_parser/internlm2_tool_parser.py
new file mode 100644
index 0000000000..a980d393d0
--- /dev/null
+++ b/lmdeploy/serve/parsers/tool_parser/internlm2_tool_parser.py
@@ -0,0 +1,46 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+from .tool_parser import ToolParser, ToolParserManager
+
+if TYPE_CHECKING:
+    from transformers import PreTrainedTokenizerBase
+
+    from lmdeploy.serve.openai.protocol import (
+        ChatCompletionRequest,
+        DeltaToolCall,
+        ToolCall,
+    )
+
+
+@ToolParserManager.register_module(['internlm', 'intern-s1'])
+class Internlm2ToolParser(ToolParser):
+    """Tool parser for InternLM JSON tool-call payloads."""
+
+    def __init__(self, tokenizer: PreTrainedTokenizerBase):
+        super().__init__(tokenizer)
+
+    def adjust_request(self, request: ChatCompletionRequest) -> ChatCompletionRequest:
+        if request.tools and request.tool_choice != 'none':
+            # Do not skip special tokens, because InternLM uses special
+            # tokens to indicate the start and end of the tool-call
+            # information.
+            request.skip_special_tokens = False
+        return request
+
+    def get_tool_open_tag(self) -> str | None:
+        return '<|action_start|><|plugin|>'
+
+    def get_tool_close_tag(self) -> str | None:
+        return '<|action_end|>'
+
+    def get_tool_payload_format(self) -> str:
+        return 'json'
+
+    def decode_tool_incremental(self, added_text: str, *, final: bool) -> list[DeltaToolCall]:
+        """Decode an incremental JSON tool payload."""
+        return self._decode_tool_incremental_json(added_text=added_text, final=final)
+
+    def parse_tool_call_complete(self, payload: str) -> ToolCall | None:
+        return self._parse_tool_call_complete_json(payload)
diff --git a/lmdeploy/serve/parsers/tool_parser/llama3_tool_parser.py b/lmdeploy/serve/parsers/tool_parser/llama3_tool_parser.py
new file mode 100644
index 0000000000..04b23fff16
--- /dev/null
+++ b/lmdeploy/serve/parsers/tool_parser/llama3_tool_parser.py
@@ -0,0 +1,33 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+
+from lmdeploy.serve.openai.protocol import (
+    DeltaToolCall,
+    ToolCall,
+)
+
+from .tool_parser import ToolParser, ToolParserManager
+
+
+@ToolParserManager.register_module('llama3')
+class Llama3JsonToolParser(ToolParser):
+    """Tool parser for Llama3 JSON tool-call payloads."""
+
+    def __init__(self, tokenizer: object):
+        super().__init__(tokenizer)
+        self.bot_token = '<|python_tag|>'
+
+    def get_tool_open_tag(self) -> str | None:
+        return self.bot_token
+
+    def get_tool_close_tag(self) -> str | None:
+        return None
+
+    def get_tool_payload_format(self) -> str:
+        return 'json'
+
+    def decode_tool_incremental(self, added_text: str, *, final: bool) -> list[DeltaToolCall]:
+        """Decode an incremental JSON tool payload."""
+        return self._decode_tool_incremental_json(added_text=added_text, final=final)
+
+    def parse_tool_call_complete(self, payload: str) -> ToolCall | None:
+        return self._parse_tool_call_complete_json(payload)
diff --git a/lmdeploy/serve/parsers/tool_parser/qwen2d5_tool_parser.py b/lmdeploy/serve/parsers/tool_parser/qwen2d5_tool_parser.py
new file mode 100644
index 0000000000..bdaa45a1f5
--- /dev/null
+++ b/lmdeploy/serve/parsers/tool_parser/qwen2d5_tool_parser.py
@@ -0,0 +1,35 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+
+
+from lmdeploy.serve.openai.protocol import (
+    DeltaToolCall,
+    ToolCall,
+)
+
+from .tool_parser import ToolParser, ToolParserManager
+
+
+@ToolParserManager.register_module(['qwen2d5'])
+class Qwen2d5ToolParser(ToolParser):
+    """Tool parser for Qwen2.5 JSON tool-call payloads."""
+
+    def __init__(self, tokenizer: object):
+        super().__init__(tokenizer)
+        self.tool_start_token = '<tool_call>'
+        self.tool_end_token = '</tool_call>'
+
+    def get_tool_open_tag(self) -> str | None:
+        return self.tool_start_token
+
+    def get_tool_close_tag(self) -> str | None:
+        return self.tool_end_token
+
+    def get_tool_payload_format(self) -> str:
+        return 'json'
+
+    def decode_tool_incremental(self, added_text: str, *, final: bool) -> list[DeltaToolCall]:
+        """Decode an incremental JSON tool payload."""
+        return self._decode_tool_incremental_json(added_text=added_text, final=final)
+
+    def parse_tool_call_complete(self, payload: str) -> ToolCall | None:
+        return self._parse_tool_call_complete_json(payload)
diff --git a/lmdeploy/serve/parsers/tool_parser/qwen3_tool_parser.py b/lmdeploy/serve/parsers/tool_parser/qwen3_tool_parser.py
new file mode 100644
index 0000000000..58a2189616
--- /dev/null
+++ b/lmdeploy/serve/parsers/tool_parser/qwen3_tool_parser.py
@@ -0,0 +1,34 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+
+from lmdeploy.serve.openai.protocol import (
+    DeltaToolCall,
+    ToolCall,
+)
+
+from .tool_parser import ToolParser, ToolParserManager
+
+
+@ToolParserManager.register_module(['qwen', 'qwen3'])
+class Qwen3ToolParser(ToolParser):
+    """Tool parser for Qwen3 JSON tool-call payloads."""
+
+    def __init__(self, tokenizer: object):
+        super().__init__(tokenizer)
+        self.tool_start_token = '<tool_call>'
+        self.tool_end_token = '</tool_call>'
+
+    def get_tool_open_tag(self) -> str | None:
+        return self.tool_start_token
+
+    def get_tool_close_tag(self) -> str | None:
+        return self.tool_end_token
+
+    def get_tool_payload_format(self) -> str:
+        return 'json'
+
+    def decode_tool_incremental(self, added_text: str, *, final: bool) -> list[DeltaToolCall]:
+        """Decode an incremental JSON tool payload."""
+        return self._decode_tool_incremental_json(added_text=added_text, final=final)
+
+    def parse_tool_call_complete(self, payload: str) -> ToolCall | None:
+        return self._parse_tool_call_complete_json(payload)
diff --git a/lmdeploy/serve/parsers/tool_parser/qwen3coder_tool_parser.py b/lmdeploy/serve/parsers/tool_parser/qwen3coder_tool_parser.py
new file mode 100644
index 0000000000..35f7771a51
--- /dev/null
+++ b/lmdeploy/serve/parsers/tool_parser/qwen3coder_tool_parser.py
@@ -0,0 +1,222 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from __future__ import annotations
+
+import json
+from typing import TYPE_CHECKING, Any
+
+from lmdeploy.serve.openai.protocol import (
+    DeltaFunctionCall,
+    DeltaToolCall,
+    FunctionCall,
+    ToolCall,
+)
+
+from .tool_parser import ToolParser, ToolParserManager
+
+if TYPE_CHECKING:
+    from lmdeploy.serve.openai.protocol import ChatCompletionRequest
+
+
+def _parse_tool_call_arguments_dict(arguments: Any) -> dict[str, Any] | None:
+    """Return dict-like tool arguments for Qwen3Coder request normalization."""
+    if not isinstance(arguments, str):
+        return None
+
+    try:
+        parsed_arguments = json.loads(arguments)
+    except (json.JSONDecodeError, TypeError):
+        return None
+    if isinstance(parsed_arguments, dict):
+        return parsed_arguments
+    return None
+
+
+@ToolParserManager.register_module(['qwen3coder'])
+class Qwen3CoderToolParser(ToolParser):
+    """Tool parser for Qwen3Coder XML tool-call payloads."""
+
+    def __init__(self, tokenizer: object):
+        super().__init__(tokenizer)
+        self.tool_start_token = '<tool_call>'
+        self.tool_end_token = '</tool_call>'
+        self.func_prefix = '<function='
+        self.coder_has_emitted_name = False
+        self.coder_has_emitted_json_start = False
+        self.coder_json_closed = False
+        self.coder_emitted_param_names: set[str] = set()
+
+    @staticmethod
+    def _normalize_messages(messages: list[dict]) -> list[dict] | None:
+        """Return a render-safe copy of request messages when needed."""
+        normalized_messages = None
+
+        for msg_idx, message in enumerate(messages):
+            if not isinstance(message, dict) or message.get('role') != 'assistant':
+                continue
+            tool_calls = message.get('tool_calls')
+            if not isinstance(tool_calls, list):
+                continue
+
+            normalized_tool_calls = None
+            for tool_idx, tool_call in enumerate(tool_calls):
+                if not isinstance(tool_call, dict):
+                    continue
+                function = tool_call.get('function')
+                if not isinstance(function, dict) or isinstance(function.get('arguments'), dict):
+                    continue
+
+                parsed_arguments = _parse_tool_call_arguments_dict(function.get('arguments'))
+                if parsed_arguments is None:
+                    continue
+
+                if normalized_messages is None:
+                    normalized_messages = list(messages)
+                if normalized_tool_calls is None:
+                    normalized_tool_calls = list(tool_calls)
+                normalized_message = dict(message)
+                normalized_message['tool_calls'] = normalized_tool_calls
+                normalized_messages[msg_idx] = normalized_message
+
+                normalized_function = dict(function)
+                normalized_function['arguments'] = parsed_arguments
+
+                normalized_tool_call = dict(tool_call)
+                normalized_tool_call['function'] = normalized_function
+                normalized_tool_calls[tool_idx] = normalized_tool_call
+
+        return normalized_messages
+
+    def get_tool_open_tag(self) -> str | None:
+        return self.tool_start_token
+
+    def get_tool_close_tag(self) -> str | None:
+        return self.tool_end_token
+
+    def get_tool_payload_format(self) -> str:
+        return 'xml'
+
+    def start_tool_call(self) -> None:
+        super().start_tool_call()
+        self.coder_has_emitted_name = False
+        self.coder_has_emitted_json_start = False
+        self.coder_json_closed = False
+        self.coder_emitted_param_names.clear()
+
+    def finish_tool_call(self) -> None:
+        super().finish_tool_call()
+        self.coder_has_emitted_name = False
+        self.coder_has_emitted_json_start = False
+        self.coder_json_closed = False
+        self.coder_emitted_param_names.clear()
+
+    def decode_tool_incremental(self, added_text: str, *, final: bool) -> list[DeltaToolCall]:
+        """Decode an incremental XML tool payload."""
+        self._tool_payload += added_text
+        func_name, args_dict, is_func_closed = self._extract_params(self._tool_payload)
+
+        out: list[DeltaToolCall] = []
+        if func_name and not self.coder_has_emitted_name:
+            out.append(
+                DeltaToolCall(
+                    id=self._active_tool_call_id,
+                    index=self._active_tool_index,
+                    type='function',
+                    function=DeltaFunctionCall(name=func_name),
+                ))
+            self.coder_has_emitted_name = True
+
+        json_fragments: list[str] = []
+        if not self.coder_has_emitted_json_start and (args_dict or is_func_closed):
+            json_fragments.append('{')
+            self.coder_has_emitted_json_start = True
+
+        for k, v in args_dict.items():
+            if k in self.coder_emitted_param_names:
+                continue
+            prefix = ', ' if len(self.coder_emitted_param_names) > 0 else ''
+            json_fragments.append(f'{prefix}\"{k}\": {json.dumps(v,
ensure_ascii=False)}') + self.coder_emitted_param_names.add(k) + + if is_func_closed and self.coder_has_emitted_json_start and not self.coder_json_closed: + json_fragments.append('}') + self.coder_json_closed = True + + if json_fragments: + out.append( + DeltaToolCall( + id=self._active_tool_call_id, + index=self._active_tool_index, + type=None, + function=DeltaFunctionCall(arguments=''.join(json_fragments)), + )) + return out + + def parse_tool_call_complete(self, payload: str) -> ToolCall | None: + func_name, args_dict, _ = self._extract_params(payload) + if not func_name: + return None + args_json = json.dumps(args_dict, ensure_ascii=False) if args_dict else '{}' + return ToolCall(function=FunctionCall(name=func_name, arguments=args_json)) + + def adjust_request(self, request: ChatCompletionRequest) -> ChatCompletionRequest: + messages = request.messages + if not isinstance(messages, list): + return request + + normalized_messages = self._normalize_request_messages(messages) + if normalized_messages is None: + return request + return request.model_copy(update={'messages': normalized_messages}) + + def _extract_params(self, content: str) -> tuple[str | None, dict[str, Any], bool]: + """Extract function name, parameter map, and close status from XML.""" + content = content.replace(self.tool_start_token, '').replace(self.tool_end_token, '').strip() + + func_name = None + func_start = content.find(self.func_prefix) + if func_start != -1: + name_start = func_start + len(self.func_prefix) + terminators = [idx for idx in (content.find('>', name_start), content.find('\n', name_start)) if idx != -1] + if terminators: + func_name = content[name_start:min(terminators)].strip() + + args_dict = {} + search_idx = 0 + while True: + param_start = content.find(self.param_prefix, search_idx) + if param_start == -1: + break + + name_start = param_start + len(self.param_prefix) + terminators = [idx for idx in (content.find('>', name_start), content.find('\n', name_start)) if idx != 
-1] + if not terminators: + break + + name_end = min(terminators) + param_name = content[name_start:name_end].strip() + + val_start = name_end + 1 + val_end = content.find(self.param_end_token, val_start) + if val_end == -1: + break + + param_val_str = content[val_start:val_end].strip() + + if param_val_str.lower() == 'null': + val = None + elif param_val_str.lower() == 'true': + val = True + elif param_val_str.lower() == 'false': + val = False + else: + try: + val = json.loads(param_val_str) + except json.JSONDecodeError: + val = param_val_str + args_dict[param_name] = val + search_idx = val_end + len(self.param_end_token) + + is_func_closed = self.func_end_token in content + return func_name, args_dict, is_func_closed diff --git a/lmdeploy/serve/parsers/tool_parser/tool_parser.py b/lmdeploy/serve/parsers/tool_parser/tool_parser.py new file mode 100644 index 0000000000..23697299c9 --- /dev/null +++ b/lmdeploy/serve/parsers/tool_parser/tool_parser.py @@ -0,0 +1,156 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
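Before the shared base class below, the XML walk performed by `_extract_params` is worth seeing standalone. This is a hedged re-implementation for illustration only — `extract_params` is a hypothetical helper, the `<function=`/`<parameter=` markers are the reconstructed Qwen3Coder convention, and it skips the real method's streaming and newline-terminator subtleties.

```python
import json


def extract_params(content: str):
    """Map a Qwen3Coder-style XML payload to (func_name, args_dict).

    Simplified sketch: assumes the <tool_call> wrapper is already removed
    and uses only '>' as the name terminator.
    """
    func_name = None
    start = content.find('<function=')
    if start != -1:
        name_start = start + len('<function=')
        name_end = content.find('>', name_start)
        if name_end != -1:
            func_name = content[name_start:name_end].strip()

    args = {}
    idx = 0
    while True:
        p = content.find('<parameter=', idx)
        if p == -1:
            break
        key_start = p + len('<parameter=')
        key_end = content.find('>', key_start)
        if key_end == -1:
            break
        key = content[key_start:key_end].strip()
        val_end = content.find('</parameter>', key_end + 1)
        if val_end == -1:
            break
        raw = content[key_end + 1:val_end].strip()
        try:
            # Numbers/booleans/objects parse as JSON; bare text falls back to str.
            args[key] = json.loads(raw)
        except json.JSONDecodeError:
            args[key] = raw
        idx = val_end + len('</parameter>')
    return func_name, args


payload = ('<function=get_weather>\n'
           '<parameter=location>\nBeijing\n</parameter>\n'
           '<parameter=days>\n3\n</parameter>\n'
           '</function>')
print(extract_params(payload))
# -> ('get_weather', {'location': 'Beijing', 'days': 3})
```

The JSON-with-string-fallback per parameter is the same design choice as the real `_extract_params`: typed values survive round-tripping, while free text stays a string.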
+# modified from https://github.com/vllm-project/vllm/tree/v0.7.3/vllm/entrypoints/openai/tool_parsers +from __future__ import annotations + +import json +from functools import cached_property +from typing import TYPE_CHECKING + +import partial_json_parser +import shortuuid +from mmengine import Registry +from partial_json_parser.core.options import Allow + +from lmdeploy.serve.openai.protocol import ( + DeltaFunctionCall, + DeltaToolCall, + FunctionCall, + ToolCall, +) + +if TYPE_CHECKING: + from lmdeploy.serve.openai.protocol import ChatCompletionRequest + +ToolParserManager = Registry('tool_parser', locations=['lmdeploy.serve.parsers.tool_parser']) + + +class ToolParser: + """Base class for model-specific tool parsers.""" + + def __init__(self, tokenizer: object): + self.model_tokenizer = tokenizer + self._tool_payload: str = '' + self._active_tool_call_id: str = '' + self._active_tool_index: int = -1 + self._name_emitted: bool = False + self._args_emitted_len: int = 0 + + @cached_property + def vocab(self) -> dict[str, int]: + # NOTE: Only PreTrainedTokenizerFast is guaranteed to have .vocab + # whereas all tokenizers have .get_vocab() + return self.model_tokenizer.get_vocab() + + def adjust_request(self, request: ChatCompletionRequest) -> ChatCompletionRequest: + """Adjust request payload before rendering, if needed.""" + if request.tools is not None and request.tool_choice != 'none': + if not isinstance(request.tool_choice, str): + request.tools = [ + item.function.model_dump() for item in request.tools + if item.function.name == request.tool_choice.function.name + ] + else: + request.tools = [item.function.model_dump() for item in request.tools] + return request + + def get_tool_open_tag(self) -> str | None: + """Return tool opening tag string, or None if unsupported.""" + raise NotImplementedError('ToolParser.get_tool_open_tag has not been implemented!') + + def get_tool_close_tag(self) -> str | None: + """Return tool closing tag string, or None if 
unsupported.""" + raise NotImplementedError('ToolParser.get_tool_close_tag has not been implemented!') + + def get_tool_payload_format(self) -> str: + """Return payload format for tool call body.""" + raise NotImplementedError('ToolParser.get_tool_payload_format has not been implemented!') + + def start_tool_call(self) -> None: + """Mark start of a tool-call block.""" + self._active_tool_index += 1 + self._active_tool_call_id = f'chatcmpl-tool-{shortuuid.random()}' + self._name_emitted = False + self._args_emitted_len = 0 + self._tool_payload = '' + + def finish_tool_call(self) -> None: + """Mark end of a tool-call block.""" + self._active_tool_call_id = '' + self._name_emitted = False + self._args_emitted_len = 0 + self._tool_payload = '' + + def decode_tool_incremental(self, added_text: str, *, final: bool) -> list[DeltaToolCall]: + """Decode incremental tool payload emitted between tool tags.""" + raise NotImplementedError('ToolParser.decode_tool_incremental has not been implemented!') + + def parse_tool_call_complete(self, payload: str) -> ToolCall | None: + """Parse one complete tool payload into OpenAI tool call object.""" + raise NotImplementedError('ToolParser.parse_tool_call_complete has not been implemented!') + + def _decode_tool_incremental_json(self, added_text: str, *, final: bool) -> list[DeltaToolCall]: + self._tool_payload += added_text + payload = self._tool_payload.strip() + if not payload: + return [] + + flags = Allow.ALL if self._name_emitted else Allow.ALL & ~Allow.STR + try: + obj = partial_json_parser.loads(payload, flags) + except partial_json_parser.core.exceptions.MalformedJSON: + return [] + if not isinstance(obj, dict): + return [] + + out: list[DeltaToolCall] = [] + if not self._name_emitted: + fn_name = obj.get('name') + if isinstance(fn_name, str) and fn_name: + out.append( + DeltaToolCall( + id=self._active_tool_call_id, + index=self._active_tool_index, + type='function', + function=DeltaFunctionCall(name=fn_name), + )) + 
self._name_emitted = True + + args_obj = obj.get('arguments', obj.get('parameters', None)) + if args_obj is None: + return out + + args_json = json.dumps(args_obj, ensure_ascii=False) + if args_json in ('{}', '[]'): + return out + + # Emit argument text only when the tool payload is complete. This keeps + # streamed argument chunks valid JSON and avoids malformed intermediate + # fragments when partial parsers expose transient dict states. + if final and len(args_json) > self._args_emitted_len: + diff = args_json[self._args_emitted_len:] + out.append( + DeltaToolCall( + id=self._active_tool_call_id, + index=self._active_tool_index, + type=None, + function=DeltaFunctionCall(arguments=diff), + )) + self._args_emitted_len = len(args_json) + return out + + @staticmethod + def _parse_tool_call_complete_json(payload: str) -> ToolCall | None: + if not payload: + return None + try: + obj = json.loads(payload) + except json.JSONDecodeError: + return None + if not isinstance(obj, dict): + return None + name = obj.get('name') + if not isinstance(name, str) or not name: + return None + args_obj = obj.get('arguments', obj.get('parameters', {})) + args_json = json.dumps(args_obj, ensure_ascii=False) + return ToolCall(function=FunctionCall(name=name, arguments=args_json)) diff --git a/tests/test_lmdeploy/serve/parsers/test_deepseek_v3_parser.py b/tests/test_lmdeploy/serve/parsers/test_deepseek_v3_parser.py new file mode 100644 index 0000000000..0080cba594 --- /dev/null +++ b/tests/test_lmdeploy/serve/parsers/test_deepseek_v3_parser.py @@ -0,0 +1,47 @@ +import pytest +from transformers import AutoTokenizer + +from lmdeploy.serve.openai.protocol import ChatCompletionRequest +from lmdeploy.serve.parsers import ResponseParserManager +from lmdeploy.serve.parsers.reasoning_parser import ReasoningParserManager + +MODEL_ID = 'deepseek-ai/DeepSeek-V3.1' + +@pytest.fixture(scope='module') +def tokenizer(): + try: + return AutoTokenizer.from_pretrained(MODEL_ID) + except Exception as exc: # 
noqa: BLE001 + pytest.skip(f'Could not load tokenizer for {MODEL_ID}: {exc}') + + +def _make_parser(enable_thinking, tokenizer): + cls = ResponseParserManager.get('default') + cls.reasoning_parser_cls = ReasoningParserManager.get('deepseek-v3') + cls.tool_parser_cls = None + request = ChatCompletionRequest( + model=MODEL_ID, + messages=[], + stream=True, + chat_template_kwargs={'enable_thinking': enable_thinking}, + ) + return cls(request=request, tokenizer=tokenizer) + + +class TestDeepSeekV3ReasoningParser: + + def test_enable_thinking_none(self, tokenizer): + parser = _make_parser(enable_thinking=None, tokenizer=tokenizer) + delta_msg, tool_emitted = parser.stream_chunk(delta_text='hello', delta_token_ids=[]) + assert tool_emitted is False + assert delta_msg is not None + assert delta_msg.content == 'hello' + assert delta_msg.reasoning_content is None + + def test_enable_thinking_true(self, tokenizer): + parser = _make_parser(enable_thinking=True, tokenizer=tokenizer) + delta_msg, tool_emitted = parser.stream_chunk(delta_text='hello', delta_token_ids=[]) + assert tool_emitted is False + assert delta_msg is not None + assert delta_msg.content is None + assert delta_msg.reasoning_content == 'hello' diff --git a/tests/test_lmdeploy/serve/parsers/test_gpt_oss_parser.py b/tests/test_lmdeploy/serve/parsers/test_gpt_oss_parser.py new file mode 100644 index 0000000000..9febac52db --- /dev/null +++ b/tests/test_lmdeploy/serve/parsers/test_gpt_oss_parser.py @@ -0,0 +1,266 @@ +from dataclasses import dataclass + +import pytest + +pytest.importorskip('openai_harmony') + +from lmdeploy.serve.parsers import _openai_harmony as openai_harmony_mod +from lmdeploy.serve.parsers import gpt_oss_response_parser as gpt_oss_mod + + +@dataclass +class _FakeMsg: + channel: str + recipient: str | None + + +class _FakeStreamableParser: + """Scripted stand-in for openai_harmony.StreamableParser.""" + + def __init__(self, script: dict[int, dict]): + self._script = script + 
self.current_channel = 'final' + self.current_recipient = None + self.last_content_delta = '' + self.messages: list[_FakeMsg] = [] + + def process(self, token: int): + event = self._script[token] + next_channel = event['channel'] + next_recipient = event.get('recipient') + + if (self.current_channel == 'commentary' and self.current_recipient + and self.current_recipient.startswith('functions.') and next_recipient != self.current_recipient): + self.messages.append(_FakeMsg(channel='commentary', recipient=self.current_recipient)) + + self.current_channel = next_channel + self.current_recipient = next_recipient + self.last_content_delta = event.get('delta', '') + + +def _scripted_events() -> dict[int, dict]: + return { + 1: { + 'channel': 'analysis', + 'recipient': None, + 'delta': 'Need tool. ', + }, + 2: { + 'channel': 'commentary', + 'recipient': 'functions.get_weather', + 'delta': '', + }, + 3: { + 'channel': 'commentary', + 'recipient': 'functions.get_weather', + 'delta': '{"location":"', + }, + 4: { + 'channel': 'commentary', + 'recipient': 'functions.get_weather', + 'delta': 'Beijing"}', + }, + 5: { + 'channel': 'commentary', + 'recipient': 'functions.get_time', + 'delta': '', + }, + 6: { + 'channel': 'commentary', + 'recipient': 'functions.get_time<|channel|>commentary', + 'delta': '{"tz":"UTC"}', + }, + 7: { + 'channel': 'final', + 'recipient': None, + 'delta': 'Result: ', + }, + 8: { + 'channel': 'final', + 'recipient': None, + 'delta': 'sunny', + }, + } + + +class TestGptOssResponseParser: + """Unit tests for :class:`GptOssResponseParser` (Harmony token + streaming).""" + + def test_stream_chunk_full_sequence(self, monkeypatch): + monkeypatch.setattr( + openai_harmony_mod, + 'StreamableParser', + lambda *args, **kwargs: _FakeStreamableParser(_scripted_events()), + ) + parser = gpt_oss_mod.GptOssResponseParser(request=object(), tokenizer=object()) + + delta, tool_emitted = parser.stream_chunk(delta_text='ignored', delta_token_ids=[1, 2, 3, 4, 5, 6, 7, 8]) + 
assert delta is not None + assert delta.content == 'Result: sunny' + assert delta.reasoning_content == 'Need tool. ' + assert tool_emitted is True + assert delta.tool_calls is not None + assert len(delta.tool_calls) == 5 + + # name delta + args delta for get_weather + assert delta.tool_calls[0].function is not None + assert delta.tool_calls[0].function.name == 'get_weather' + assert delta.tool_calls[1].function is not None + assert delta.tool_calls[1].function.arguments == '{"location":"' + assert delta.tool_calls[2].function is not None + assert delta.tool_calls[2].function.arguments == 'Beijing"}' + + # second tool: name delta + sanitized malformed recipient arguments delta. + assert delta.tool_calls[3].function is not None + assert delta.tool_calls[3].function.name == 'get_time' + assert delta.tool_calls[4].function is not None + assert delta.tool_calls[4].function.arguments == '{"tz":"UTC"}' + + def test_parse_complete_full_sequence(self, monkeypatch): + monkeypatch.setattr( + openai_harmony_mod, + 'StreamableParser', + lambda *args, **kwargs: _FakeStreamableParser(_scripted_events()), + ) + parser = gpt_oss_mod.GptOssResponseParser(request=object(), tokenizer=object()) + + content, tool_calls, reasoning = parser.parse_complete(text='', token_ids=[1, 2, 3, 4, 5, 6, 7, 8]) + assert content == 'Result: sunny' + assert reasoning == 'Need tool. 
' + assert tool_calls is not None + assert [call.function.name for call in tool_calls] == ['get_weather', 'get_time'] + assert [call.function.arguments for call in tool_calls] == ['{"location":"Beijing"}', '{"tz":"UTC"}'] + + def test_stream_chunk_bootstrap_empty_before_any_content(self, monkeypatch): + monkeypatch.setattr( + openai_harmony_mod, + 'StreamableParser', + lambda *args, **kwargs: _FakeStreamableParser({}), + ) + parser = gpt_oss_mod.GptOssResponseParser(request=object(), tokenizer=object()) + + delta, tool_emitted = parser.stream_chunk('', []) + assert delta is not None + assert delta.role == 'assistant' + assert delta.content == '' + assert tool_emitted is False + + def test_stream_chunk_empty_after_content_started_returns_none(self, monkeypatch): + monkeypatch.setattr( + openai_harmony_mod, + 'StreamableParser', + lambda *args, **kwargs: _FakeStreamableParser({}), + ) + parser = gpt_oss_mod.GptOssResponseParser(request=object(), tokenizer=object()) + + parser.stream_chunk('warmup', []) + delta, tool_emitted = parser.stream_chunk('', []) + assert delta is None + assert tool_emitted is False + + def test_stream_chunk_text_only_without_token_ids(self, monkeypatch): + monkeypatch.setattr( + openai_harmony_mod, + 'StreamableParser', + lambda *args, **kwargs: _FakeStreamableParser({}), + ) + parser = gpt_oss_mod.GptOssResponseParser(request=object(), tokenizer=object()) + + delta, tool_emitted = parser.stream_chunk('plain text', []) + assert delta is not None + assert delta.content == 'plain text' + assert delta.reasoning_content is None + assert delta.tool_calls is None + assert tool_emitted is False + + def test_stream_chunk_token_ids_all_empty_delta_returns_none(self, monkeypatch): + script = { + 10: {'channel': 'final', 'recipient': None, 'delta': ''}, + } + monkeypatch.setattr( + openai_harmony_mod, + 'StreamableParser', + lambda *args, **kwargs: _FakeStreamableParser(script), + ) + parser = gpt_oss_mod.GptOssResponseParser(request=object(), 
tokenizer=object()) + + delta, tool_emitted = parser.stream_chunk('', [10]) + assert delta is None + assert tool_emitted is False + + def test_stream_chunk_analysis_without_tool_accumulates_reasoning(self, monkeypatch): + script = { + 1: {'channel': 'analysis', 'recipient': None, 'delta': 'think '}, + 2: {'channel': 'analysis', 'recipient': None, 'delta': 'more'}, + } + monkeypatch.setattr( + openai_harmony_mod, + 'StreamableParser', + lambda *args, **kwargs: _FakeStreamableParser(script), + ) + parser = gpt_oss_mod.GptOssResponseParser(request=object(), tokenizer=object()) + + delta, tool_emitted = parser.stream_chunk('', [1, 2]) + assert delta is not None + assert delta.content is None + assert delta.reasoning_content == 'think more' + assert delta.tool_calls is None + assert tool_emitted is False + + def test_parse_complete_without_token_ids_returns_raw_text(self): + parser = gpt_oss_mod.GptOssResponseParser(request=object(), tokenizer=object()) + + content, tool_calls, reasoning = parser.parse_complete('hello', token_ids=[]) + assert content == 'hello' + assert tool_calls is None + assert reasoning is None + + def test_parse_complete_without_token_ids_empty_text(self): + parser = gpt_oss_mod.GptOssResponseParser(request=object(), tokenizer=object()) + + content, tool_calls, reasoning = parser.parse_complete('', token_ids=None) + assert content is None + assert tool_calls is None + assert reasoning is None + + def test_parse_complete_appends_tool_call_still_open_at_eof(self, monkeypatch): + """Final `active` tool dict is appended when the stream ends in a tool + channel.""" + script = { + 1: { + 'channel': 'commentary', + 'recipient': 'functions.echo', + 'delta': '{"x":1}', + }, + } + monkeypatch.setattr( + openai_harmony_mod, + 'StreamableParser', + lambda *args, **kwargs: _FakeStreamableParser(script), + ) + parser = gpt_oss_mod.GptOssResponseParser(request=object(), tokenizer=object()) + + content, tool_calls, reasoning = parser.parse_complete(text='', 
token_ids=[1]) + assert content is None + assert reasoning is None + assert tool_calls is not None + assert len(tool_calls) == 1 + assert tool_calls[0].function.name == 'echo' + assert tool_calls[0].function.arguments == '{"x":1}' + + @pytest.mark.parametrize( + ('recipient', 'expected'), + [ + (None, None), + ('', None), + ('not-a-tool', None), + ('functions.', None), + ('functions.foo', 'foo'), + ('prefix functions.bar suffix', 'bar'), + ('functions.bash<|channel|>commentary', 'bash'), + ('functions.tool_name<|extra|', 'tool_name'), + ], + ) + def test_extract_tool_name(self, recipient, expected): + assert gpt_oss_mod.GptOssResponseParser._extract_tool_name(recipient) == expected diff --git a/tests/test_lmdeploy/serve/parsers/test_interns1_parser.py b/tests/test_lmdeploy/serve/parsers/test_interns1_parser.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/test_lmdeploy/serve/parsers/test_qwen3_5_parsers.py b/tests/test_lmdeploy/serve/parsers/test_qwen3_5_parsers.py new file mode 100644 index 0000000000..8d2eaf3be1 --- /dev/null +++ b/tests/test_lmdeploy/serve/parsers/test_qwen3_5_parsers.py @@ -0,0 +1,136 @@ +import pytest +from transformers import AutoTokenizer + +from lmdeploy.serve.openai.protocol import ChatCompletionRequest, DeltaToolCall +from lmdeploy.serve.parsers import ResponseParserManager +from lmdeploy.serve.parsers.reasoning_parser import ReasoningParserManager +from lmdeploy.serve.parsers.tool_parser import ToolParserManager + +MODEL_ID = 'Qwen/Qwen3.5-35B-A3B' + + +@pytest.fixture(scope='module') +def tokenizer(): + return AutoTokenizer.from_pretrained(MODEL_ID) + + +@pytest.fixture() +def response_parser(tokenizer): + cls = ResponseParserManager.get('default') + cls.reasoning_parser_cls = ReasoningParserManager.get('default') + cls.tool_parser_cls = ToolParserManager.get('qwen3coder') + + request = ChatCompletionRequest( + model=MODEL_ID, + messages=[], + stream=True, + tool_choice='auto', + 
chat_template_kwargs={'enable_thinking': True},
+    )
+    return cls(request=request, tokenizer=tokenizer)
+
+
+REFERENCE_CHUNKS = [
+    # (delta_text, emitted_delta_msg, reasoning_content, content,
+    #  tool_emitted, function_name, function_arguments, tool_call_type)
+    # Short representative reasoning stream; literal text is irrelevant.
+    ('计划', True, '计划', None, False, None, None, None),
+    ('调用', True, '调用', None, False, None, None, None),
+    ('get', True, 'get', None, False, None, None, None),
+    ('_current', True, '_current', None, False, None, None, None),
+    ('_temperature', True, '_temperature', None, False, None, None, None),
+    ('函数', True, '函数', None, False, None, None, None),
+    ('并提供', True, '并提供', None, False, None, None, None),
+    ('location', True, 'location', None, False, None, None, None),
+    ('参数', True, '参数', None, False, None, None, None),
+    ('。', True, '。', None, False, None, None, None),
+    ('\n', True, '\n', None, False, None, None, None),
+    ('</think>', False, None, None, False, None, None, None),
+    ('\n\n', True, None, '\n\n', False, None, None, None),
+    # Tool call section: placeholder; will be updated to match Qwen3.5 XML-style.
+    ('<tool_call>', False, None, None, False, None, None, None),
+    ('\n', False, None, None, False, None, None, None),
+    ('<', False, None, None, False, None, None, None),
+    ('function', False, None, None, False, None, None, None),
+    ('=get', False, None, None, False, None, None, None),
+    ('_current', False, None, None, False, None, None, None),
+    ('_temperature', False, None, None, False, None, None, None),
+    ('>', True, None, None, True, 'get_current_temperature', None, 'function'),
+    ('\n', False, None, None, False, None, None, None),
+    ('<', False, None, None, False, None, None, None),
+    ('parameter', False, None, None, False, None, None, None),
+    ('=location', False, None, None, False, None, None, None),
+    ('>', False, None, None, False, None, None, None),
+    ('\n', False, None, None, False, None, None, None),
+    ('Be', False, None, None, False, None, None, None),
+    ('ijing', False, None, None, False, None, None, None),
+    (',', False, None, None, False, None, None, None),
+    (' China', False, None, None, False, None, None, None),
+    ('\n', False, None, None, False, None, None, None),
+    ('</parameter', False, None, None, False, None, None, None),
+    # Maps `</parameter>` to a single id; Qwen3Coder may emit accumulated JSON args in one delta.
+    ('>', True, None, None, True, None, '{"location": "Beijing, China"', None),
+    ('\n', False, None, None, False, None, None, None),
+    ('</function>', True, None, None, True, None, '}', None),
+    ('\n', False, None, None, False, None, None, None),
+    ('</tool_call>', False, None, None, False, None, None, None),
+    ('', True, None, '', False, None, None, None),
+]
+
+
+class TestQwen3_5ResponseParserStreaming:
+    """Integration test for ResponseParser.stream_chunk with Qwen3.5 Coder
+    parsers."""
+
+    @staticmethod
+    def _encode_ids(tokenizer, text: str) -> list[int]:
+        return tokenizer.encode(text, add_bos=False, add_special_tokens=False)
+
+    def test_stream_chunk_matches_reference(self, tokenizer, response_parser):
+        """Feed the real streaming sequence into ResponseParser.stream_chunk
+        and verify each parsed chunk.
+ + Expectations for tool_calls will be refined once the Qwen3.5 ground-truth stream is finalized. + """ + + for (delta_text, exp_delta_msg, exp_reasoning, exp_content, exp_tool_emitted, + exp_function_name, exp_function_arguments, + exp_type) in REFERENCE_CHUNKS: + delta_ids = self._encode_ids(tokenizer, delta_text) + delta_msg, tool_emitted = response_parser.stream_chunk( + delta_text=delta_text, + delta_token_ids=delta_ids, + ) + if exp_delta_msg is False: + assert delta_msg is None + continue + + assert delta_msg.reasoning_content == exp_reasoning + assert delta_msg.content == exp_content + + # Tool-call expectations in this fixture are placeholders for now. + # Only enforce the exact tool_emitted flag when an explicit tool + # delta shape is provided. + if ( + exp_function_name is None + and exp_function_arguments is None + and exp_type is None + and exp_reasoning is None + and exp_content is None + ): + continue + + assert tool_emitted == exp_tool_emitted + + if tool_emitted: + assert delta_msg.tool_calls is not None + assert len(delta_msg.tool_calls) == 1 + call = delta_msg.tool_calls[0] + assert isinstance(call, DeltaToolCall) + assert call.type == exp_type + assert call.function is not None + assert call.function.name == exp_function_name + assert call.function.arguments == exp_function_arguments diff --git a/tests/test_lmdeploy/serve/parsers/test_qwen3_parser.py b/tests/test_lmdeploy/serve/parsers/test_qwen3_parser.py new file mode 100644 index 0000000000..9750fdf4f2 --- /dev/null +++ b/tests/test_lmdeploy/serve/parsers/test_qwen3_parser.py @@ -0,0 +1,375 @@ +import pytest +from transformers import AutoTokenizer + +from lmdeploy.serve.openai.protocol import ChatCompletionRequest, DeltaToolCall +from lmdeploy.serve.parsers import ResponseParserManager +from lmdeploy.serve.parsers.reasoning_parser import ReasoningParserManager +from lmdeploy.serve.parsers.tool_parser import ToolParserManager + +MODEL_ID = 'Qwen/Qwen3-8B' + + +@pytest.fixture(scope='module') 
+def tokenizer():
+    try:
+        return AutoTokenizer.from_pretrained(MODEL_ID)
+    except Exception as exc:  # noqa: BLE001
+        pytest.skip(f'Could not load tokenizer for {MODEL_ID}: {exc}')
+
+
+@pytest.fixture()
+def response_parser(tokenizer):
+    # Configure ResponseParser to use unified reasoning parser and Qwen3 tool parser.
+    cls = ResponseParserManager.get('default')
+    cls.reasoning_parser_cls = ReasoningParserManager.get('default')
+    cls.tool_parser_cls = ToolParserManager.get('qwen3')
+
+    request = ChatCompletionRequest(
+        model=MODEL_ID,
+        messages=[],
+        stream=True,
+        # Enable tool parsing (any value other than "none" works).
+        tool_choice='auto',
+        # Explicitly enable thinking mode to exercise reasoning parsing.
+        chat_template_kwargs={'enable_thinking': True},
+    )
+    return cls(request=request, tokenizer=tokenizer)
+
+
+# Reference streaming sequence
+# reasoning part: <think>This is the mock user prompt</think>
+REASONING_0 = [
+    # (delta_text, emitted_delta_msg, reasoning_content, content,
+    #  tool_emitted, function_name, function_arguments, tool_call_type)
+    # reasoning part
+    ('<think>', False, None, None, False, None, None, None),
+    ('This is the mock', True, 'This is the mock', None, False, None, None, None),
+    (' user prompt', True, ' user prompt', None, False, None, None, None),
+    ('</think>', False, None, None, False, None, None, None),
+]
+# reasoning part: This is the mock user prompt</think>
+REASONING_1 = [
+    # (delta_text, emitted_delta_msg, reasoning_content, content,
+    #  tool_emitted, function_name, function_arguments, tool_call_type)
+    # reasoning part
+    ('This is the mock', True, 'This is the mock', None, False, None, None, None),
+    (' user prompt', True, ' user prompt', None, False, None, None, None),
+    ('</think>', False, None, None, False, None, None, None),
+]
+
+# tool call part: {"name": "get_weather", "arguments": {"location": "北京", "unit": "celsius"}}
+TOOL_CALL_0 = [
+    # (delta_text, emitted_delta_msg, reasoning_content, content,
+    #  tool_emitted, function_name, function_arguments, tool_call_type)
+    # tool call part
+    ('<tool_call>', False, None, None, False, None, None, None),
+    ('\n', False, None, None, False, None, None, None),
+    ('{"', False, None, None, False, None, None, None),
+    ('name', False, None, None, False, None, None, None),
+    ('":', False, None, None, False, None, None, None),
+    (' "', False, None, None, False, None, None, None),
+    ('get', False, None, None, False, None, None, None),
+    ('_weather', False, None, None, False, None, None, None),
+    ('",', True, None, None, True, 'get_weather', None, 'function'),
+    (' "', False, None, None, False, None, None, None),
+    ('arguments', False, None, None, False, None, None, None),
+    ('":', False, None, None, False, None, None, None),
+    (' {"', False, None, None, False, None, None, None),
+    ('location', False, None, None, False, None, None, None),
+    ('":', False, None, None, False, None, None, None),
+    (' "', False, None, None, False, None, None, None),
+    ('北京', False, None, None, False, None, None, None),
+    ('",', False, None, None, False, None, None, None),
+    (' "', False, None, None, False, None, None, None),
+    ('unit', False, None, None, False, None, None, None),
+    ('":', False, None, None, False, None, None, None),
+    (' "', False, None, None, False, None, None, None),
+    ('celsius', False, None, None, False, None, None, None),
+    ('"}}\n', False, None, None, False, None, None, None),
+    ('</tool_call>', True, None, None, True, None, '{"location": "北京", "unit": "celsius"}', None),
+]
+
+REFERENCE_CHUNKS_0 = REASONING_0 + [
+    ('\n\n', True, None, '\n\n', False, None, None, None)] + TOOL_CALL_0 + [
+    ('', True, None, '', False, None, None, None),
+]
+
+REFERENCE_CHUNKS_1 = REASONING_1 + [
+    ('\n\n', True, None, '\n\n', False, None, None, None)] + TOOL_CALL_0 + [
+    ('', True, None, '', False, None, None, None),
+]
+
+REFERENCE_CHUNKS_2 = [
+    # (delta_text, emitted_delta_msg, reasoning_content, content,
+    #  tool_emitted, function_name, function_arguments, tool_call_type)
+    # reasoning part
+    ('This is the mock', True, 'This is the mock', None, False, None, None, None),
+    (' user prompt.', True, ' user prompt.', None, False, None, None, None),
+    (' reasoning</think>\n\n<tool_call>\n', True, ' reasoning', None, False, None, None, None),
+    ('{"', True, None, '\n\n', False, None, None, None),
+    ('name', False, None, None, False, None, None, None),
+    ('":', False, None, None, False, None, None, None),
+    (' "', False, None, None, False, None, None, None),
+    ('get', False, None, None, False, None, None, None),
+    ('_weather', False, None, None, False, None, None, None),
+    ('",', True, None, None, True, 'get_weather', None, 'function'),
+    (' "', False, None, None, False, None, None, None),
+    ('arguments', False, None, None, False, None, None, None),
+    ('":', False, None, None, False, None, None, None),
+    (' {"', False, None, None, False, None, None, None),
+    ('location', False, None, None, False, None, None, None),
+    ('":', False, None, None, False, None, None, None),
+    (' "', False, None, None, False, None, None, None),
+    ('北京', False, None, None, False, None, None, None),
+    ('",', False, None, None, False, None, None, None),
+    (' "', False, None, None, False, None, None, None),
+    ('unit', False, None, None, False, None, None, None),
+    ('":', False, None, None, False, None, None, None),
+    (' "', False, None, None, False, None, None, None),
+    ('celsius', False, None, None, False, None, None, None),
+    ('"}}\n', False, None, None, False, None, None, None),
+    ('</tool_call>', True, None, None, True, None, '{"location": "北京", "unit": "celsius"}', None),
+    ('', True, None, '', False, None, None, None),
+]
+
+
+class TestQwenResponseParserStreaming:
+    """Integration test for ResponseParser.stream_chunk with Qwen3 parsers."""
+
+    @staticmethod
+    def _encode_ids(tokenizer, text: str) -> list[int]:
+        return tokenizer.encode(text, add_bos=False, add_special_tokens=False)
+
+    @pytest.mark.parametrize('reference_chunks', [REFERENCE_CHUNKS_0, REFERENCE_CHUNKS_1, REFERENCE_CHUNKS_2])
+    def
test_stream_chunk_matches_reference(self, tokenizer, response_parser, reference_chunks):
+        """Feed the real streaming sequence into ResponseParser.stream_chunk
+        and verify each parsed chunk.
+
+        Input:
+        - Strictly use the reference token stream (including special tag
+          tokens, \\n, and partial tokens, ...).
+
+        Checks:
+        - reasoning: whenever an expected reasoning chunk is provided, the
+          parser must emit exactly that reasoning_content.
+        - content: only after ``</think>``, we expect a single \\n\\n.
+        - tool_calls:
+          - for each step, tool_emitted must match expected_tool_emitted;
+          - whenever ResponseParser actually emits DeltaToolCall, we check:
+            - the first time a function.name appears, it must equal
+              get_weather;
+            - any function.arguments increments are concatenated and validated
+              after streaming completes.
+        """
+
+        for (delta_text, exp_delta_msg, exp_reasoning, exp_content, exp_tool_emitted,
+                exp_function_name, exp_function_arguments,
+                exp_type) in reference_chunks:
+            delta_ids = self._encode_ids(tokenizer, delta_text)
+            delta_msg, tool_emitted = response_parser.stream_chunk(
+                delta_text=delta_text,
+                delta_token_ids=delta_ids,
+            )
+            if not exp_delta_msg:
+                assert delta_msg is None
+                continue
+            # reasoning: when an expected reasoning chunk is provided, it must match exactly.
+            assert delta_msg.reasoning_content == exp_reasoning
+            assert delta_msg.content == exp_content
+            assert tool_emitted == exp_tool_emitted
+            if tool_emitted:
+                assert delta_msg.tool_calls is not None
+                assert len(delta_msg.tool_calls) == 1
+                call = delta_msg.tool_calls[0]
+                assert isinstance(call, DeltaToolCall)
+                assert call.type == exp_type
+                assert call.function is not None
+                assert call.function.name == exp_function_name
+                assert call.function.arguments == exp_function_arguments
+
+    def test_stream_chunk_handles_mixed_reasoning_content_tool(self, tokenizer, response_parser):
+        """A single delta may contain reasoning/content/tool segments together.
+
+        This test covers chunk shapes:
+        1) ``<think>``
+        2) `` Let me think ``
+        3) ``The answer is 9 </think> OK. The``
+        4) ``fine. </think>\\n\\n ``
+        """
+
+        def _call(delta_text: str):
+            ids = self._encode_ids(tokenizer, delta_text)
+            return response_parser.stream_chunk(delta_text=delta_text, delta_token_ids=ids)
+
+        # 1) tag-only chunk should be swallowed
+        delta_msg, tool_emitted = _call('<think>')
+        assert delta_msg is None
+        assert tool_emitted is False
+
+        # 2) open-think plus reasoning text should emit only reasoning
+        delta_msg, tool_emitted = _call(' Let me think ')
+        assert delta_msg is not None
+        assert delta_msg.reasoning_content == ' Let me think '
+        assert delta_msg.content is None
+        assert tool_emitted is False
+
+        # 3) chunk carries reasoning end + normal content.
+        # New parser emits ordered events, so this call emits reasoning first.
+        delta_msg, tool_emitted = _call('The answer is 9 </think> OK. The')
+        assert delta_msg is not None
+        assert delta_msg.reasoning_content == 'The answer is 9 '
+        assert delta_msg.content is None
+        assert tool_emitted is False
+
+        # Next call flushes queued plain content from previous chunk first.
+        delta_msg, tool_emitted = _call('fine. </think>\n\n ')
+        assert delta_msg is not None
+        assert delta_msg.reasoning_content is None
+        assert delta_msg.content == ' OK. The'
+        assert tool_emitted is False
+
+        # Flush the next queued plain segment from chunk-4.
+        delta_msg, tool_emitted = _call('')
+        assert delta_msg is not None
+        # Stray closing tag after reasoning has ended is treated as plain content.
+        assert delta_msg.reasoning_content is None
+        assert delta_msg.content == 'fine. </think>\n\n '
+        assert tool_emitted is False
+
+    def test_stream_chunk_tool_enabled_without_reasoning_parser(self, tokenizer):
+        """When reasoning parser is disabled, tool parsing still works.
+
+        This proves the tool branch is reachable from plain mode after seeing the tool open tag, even with no reasoning
+        parser configured.
+ """ + cls = ResponseParserManager.get('default') + old_reasoning_cls = cls.reasoning_parser_cls + old_tool_cls = cls.tool_parser_cls + try: + cls.reasoning_parser_cls = None + cls.tool_parser_cls = ToolParserManager.get('qwen3') + + request = ChatCompletionRequest( + model=MODEL_ID, + messages=[], + stream=True, + tool_choice='auto', + chat_template_kwargs={'enable_thinking': False}, + ) + parser = cls(request=request, tokenizer=tokenizer) + + chunks = [ + 'prefix ', + '', + '\n', + '{"', + 'name', + '":', + ' "', + 'get', + '_weather', + '",', + ] + tool_seen = False + for chunk in chunks: + delta_ids = self._encode_ids(tokenizer, chunk) + delta_msg, tool_emitted = parser.stream_chunk(delta_text=chunk, delta_token_ids=delta_ids) + if delta_msg is not None: + assert delta_msg.reasoning_content is None + if tool_emitted: + tool_seen = True + assert delta_msg is not None + assert delta_msg.tool_calls is not None + assert delta_msg.tool_calls[0].function is not None + assert delta_msg.tool_calls[0].function.name == 'get_weather' + assert tool_seen is True + finally: + cls.reasoning_parser_cls = old_reasoning_cls + cls.tool_parser_cls = old_tool_cls + + def test_stream_chunk_reasoning_without_open_tag(self, tokenizer, response_parser): + """Qwen thinking mode may omit ```` and start directly with + reasoning. + + In this case, chunks before ```` must be emitted as + ``reasoning_content``. + """ + + def _call(delta_text: str): + delta_ids = self._encode_ids(tokenizer, delta_text) + return response_parser.stream_chunk(delta_text=delta_text, delta_token_ids=delta_ids) + + # No opening tag, but still in reasoning mode initially. 
+        delta_msg, tool_emitted = _call('Let me reason ')
+        assert delta_msg is not None
+        assert delta_msg.reasoning_content == 'Let me reason '
+        assert delta_msg.content is None
+        assert tool_emitted is False
+
+        delta_msg, tool_emitted = _call('step by step')
+        assert delta_msg is not None
+        assert delta_msg.reasoning_content == 'step by step'
+        assert delta_msg.content is None
+        assert tool_emitted is False
+
+        # Closing tag chunk itself is swallowed.
+        delta_msg, tool_emitted = _call('</think>')
+        assert delta_msg is None
+        assert tool_emitted is False
+
+        # After close tag, emit normal content.
+        delta_msg, tool_emitted = _call(' final answer')
+        assert delta_msg is not None
+        assert delta_msg.reasoning_content is None
+        assert delta_msg.content == ' final answer'
+        assert tool_emitted is False
+
+    def test_stream_chunk_preserves_order(self, tokenizer, response_parser):
+        """Mixed single chunk should preserve event order without content
+        merge."""
+        class PlainStartQwenReasoningParser(ReasoningParserManager.get('default')):
+
+            def starts_in_reasoning_mode(self) -> bool:
+                return False
+
+        cls = ResponseParserManager.get('default')
+        old_reasoning_cls = cls.reasoning_parser_cls
+        old_tool_cls = cls.tool_parser_cls
+        try:
+            cls.reasoning_parser_cls = PlainStartQwenReasoningParser
+            cls.tool_parser_cls = ToolParserManager.get('qwen3')
+            request = ChatCompletionRequest(
+                model=MODEL_ID,
+                messages=[],
+                stream=True,
+                tool_choice='auto',
+                chat_template_kwargs={'enable_thinking': True},
+            )
+            parser = cls(request=request, tokenizer=tokenizer)
+
+            delta_text = 'content-xxx <think> reasoning-yyy </think> content-zzz '
+            delta_ids = self._encode_ids(tokenizer, delta_text)
+
+            # 1st event: plain content before <think>
+            delta_msg, tool_emitted = parser.stream_chunk(delta_text=delta_text, delta_token_ids=delta_ids)
+            assert delta_msg is not None
+            assert delta_msg.content == 'content-xxx '
+            assert delta_msg.reasoning_content is None
+            assert tool_emitted is False
+
+            # 2nd event: reasoning segment
+
delta_msg, tool_emitted = parser.stream_chunk(delta_text='', delta_token_ids=[])
+            assert delta_msg is not None
+            assert delta_msg.content is None
+            assert delta_msg.reasoning_content == ' reasoning-yyy '
+            assert tool_emitted is False
+
+            # 3rd event: trailing content segment after </think>
+            delta_msg, tool_emitted = parser.stream_chunk(delta_text='', delta_token_ids=[])
+            assert delta_msg is not None
+            assert delta_msg.content == ' content-zzz '
+            assert delta_msg.reasoning_content is None
+            assert tool_emitted is False
+        finally:
+            cls.reasoning_parser_cls = old_reasoning_cls
+            cls.tool_parser_cls = old_tool_cls
diff --git a/tests/test_lmdeploy/test_harmony_gpt_oss_parser.py b/tests/test_lmdeploy/test_harmony_gpt_oss_parser.py
deleted file mode 100644
index 7624ff4d17..0000000000
--- a/tests/test_lmdeploy/test_harmony_gpt_oss_parser.py
+++ /dev/null
@@ -1,328 +0,0 @@
-import collections
-import json
-import os
-import sys
-import time
-import types
-from collections.abc import Generator
-
-import pytest
-import shortuuid
-
-# Ensure local package is imported (not any site-packages installation)
-REPO_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
-if REPO_ROOT not in sys.path:
-    sys.path.insert(0, REPO_ROOT)
-
-
-def _install_openai_harmony_stub():
-    """Install a minimal stub for `openai_harmony` so the module imports
-    without the real dependency.
-
-    The GptOssChatParser test injects its own dummy parser, so the stub is sufficient.
- """ - if 'openai_harmony' in sys.modules: - return - m = types.ModuleType('openai_harmony') - - class HarmonyEncodingName: - HARMONY_GPT_OSS = 'HARMONY_GPT_OSS' - - class Role: - ASSISTANT = 'assistant' - - class StreamableParser: # pragma: no cover - constructor only used - - def __init__(self, encoding, role=None): - self.encoding = encoding - self.role = role - - def load_harmony_encoding(name): # pragma: no cover - not used in test - return object() - - m.HarmonyEncodingName = HarmonyEncodingName - m.Role = Role - m.StreamableParser = StreamableParser - m.load_harmony_encoding = load_harmony_encoding - sys.modules['openai_harmony'] = m - - -TestExpects = collections.namedtuple('TestExpects', 'func_name location') - - -class DummyParser: - """A minimal stand-in for Harmony's StreamableParser with channels. - - Control tokens: - -1: start functions.get_weather (commentary) - -4: start functions.get_time (commentary) - -6: start functions.get_weather (again) - -9: end current tool call, append to `messages` - -2: switch to final (visible) content - -3: switch to analysis (reasoning) - Other tokens are interpreted as chr(token). 
- """ - - class _Msg: - - def __init__(self, channel, recipient): - self.channel = channel - self.recipient = recipient - - def __init__(self): - self.current_channel = None - self.current_recipient = None - self.last_content_delta = '' - self.messages = [] - - def process(self, token): - if token == -1: - self.current_channel = 'commentary' - self.current_recipient = 'functions.get_weather' - self.last_content_delta = '' - return - if token == -4: - self.current_channel = 'commentary' - self.current_recipient = 'functions.get_time' - self.last_content_delta = '' - return - if token == -6: - self.current_channel = 'commentary' - self.current_recipient = 'functions.get_weather' - self.last_content_delta = '' - return - if token == -9: - if self.current_channel == 'commentary' and self.current_recipient and self.current_recipient.startswith( - 'functions.'): - self.messages.append(self._Msg(self.current_channel, self.current_recipient)) - # reset recipient to signal end of current tool call - self.current_recipient = None - self.current_channel = None - self.last_content_delta = '' - return - if token == -2: - self.current_channel = 'final' - self.current_recipient = None - self.last_content_delta = '' - return - if token == -3: - self.current_channel = 'analysis' - self.current_recipient = None - self.last_content_delta = '' - return - # regular character token - self.last_content_delta = chr(token) - - -def _chat_completion_v1(request, token_chunks: list[list[int]]): - from lmdeploy.serve.openai.harmony_utils import GptOssChatParser - from lmdeploy.serve.openai.protocol import ( - ChatCompletionResponse, - ChatCompletionResponseChoice, - ChatCompletionResponseStreamChoice, - ChatCompletionStreamResponse, - UsageInfo, - ) - - request_id = f'chat-{shortuuid.random()}' - created_time = int(time.time()) - model_name = request.model - - parser = GptOssChatParser() - parser.parser = DummyParser() - - if request.stream: - - def completion_stream_generator() -> 
Generator['ChatCompletionStreamResponse', None, None]: - finish_reason = 'stop' - for chunk in token_chunks: - delta_message = parser.parse_streaming(chunk) - choice_data = ChatCompletionResponseStreamChoice(index=0, - delta=delta_message, - finish_reason=finish_reason, - logprobs=None) - response = ChatCompletionStreamResponse(id=request_id, - created=created_time, - model=model_name, - choices=[choice_data], - usage=None) - yield response - - return completion_stream_generator() - - # Non-stream path: parse all tokens at once using parse_full - tokens: list[int] = [] - for c in token_chunks: - tokens.extend(c) - message = parser.parse_full(tokens) - finish_reason = 'tool_calls' if message.tool_calls else 'stop' - choice_data = ChatCompletionResponseChoice(index=0, message=message, finish_reason=finish_reason) - return ChatCompletionResponse(id=request_id, - created=created_time, - model=model_name, - choices=[choice_data], - usage=UsageInfo()) - - -def _stream_parse(request, token_chunks: list[list[int]]): - from lmdeploy.serve.openai.protocol import DeltaMessage - - content = '' - reasoning_content = '' - tool_calls_by_index = {} - - for i, stream_resp in enumerate(_chat_completion_v1(request, token_chunks)): - delta_message: DeltaMessage = stream_resp.choices[0].delta - if delta_message.content: - content += delta_message.content - if delta_message.reasoning_content: - reasoning_content += delta_message.reasoning_content - if delta_message.tool_calls: - for c in delta_message.tool_calls: - idx = c.index - existing_call = tool_calls_by_index.get(idx, None) - if not existing_call: - tool_calls_by_index[idx] = c - continue - if c.function.name: - existing_call.function.name = c.function.name - if c.function.arguments: - existing_call.function.arguments = existing_call.function.arguments or '' - existing_call.function.arguments += c.function.arguments - # sorted list for stable order - tool_calls = [tool_calls_by_index[i] for i in 
sorted(tool_calls_by_index.keys())] - return content, reasoning_content, tool_calls - - -def _t(s: str) -> list[int]: - return [ord(c) for c in s] - - -# Basic: single function call split across two chunks (bug repro scenario) -TOKENS_SINGLE_CALL_TWO_CHUNKS = [ - [-1] + _t('{"location": "Paris'), - _t(', France"}'), -] - -# Multiple calls with indices and different function names -TOKENS_TWO_CALLS_DIFFERENT_FUNCS = [ - [-1] + _t('{"location": "Berlin"}') + [-9] + [-4] + _t('{"city": "New'), - _t(' York"}') + [-9], -] - -# Interleaved channels: analysis, tool call, final content -TOKENS_INTERLEAVED = [ - [-3] + _t('Thinking about the weather. ') + [-1] + _t('{"location": "Par'), - _t('is, France"}') + [-9] + [-2] + _t('Fetching the weather now.'), -] - -# Two calls, same function name, indices increment -TOKENS_TWO_CALLS_SAME_FUNC = [ - [-1] + _t('{"location": "Tokyo"}') + [-9], - [-6] + _t('{"location": "Ky'), - _t('oto"}') + [-9], -] - - -@pytest.mark.parametrize(('token_chunks', 'expects'), [ - (TOKENS_SINGLE_CALL_TWO_CHUNKS, [TestExpects('get_weather', 'Paris, France')]), -]) -def test_parser_stream_basic(token_chunks: list[list[int]], expects: list[TestExpects]): - from lmdeploy.serve.openai.protocol import ChatCompletionRequest - - _install_openai_harmony_stub() - request = ChatCompletionRequest(model='gpt-oss', messages=[], stream=True) - content, reasoning_content, tool_calls = _stream_parse(request, token_chunks) - - assert len(tool_calls) == len(expects) - for parsed_call, expected_call in zip(tool_calls, expects): - assert parsed_call.function.name == expected_call.func_name - args = json.loads(parsed_call.function.arguments) - assert args['location'] == expected_call.location - assert content.strip() == '' - assert (reasoning_content or '').strip() == '' - - -def test_parser_stream_multiple_calls_indices(): - from lmdeploy.serve.openai.protocol import ChatCompletionRequest - - _install_openai_harmony_stub() - request = 
ChatCompletionRequest(model='gpt-oss', messages=[], stream=True) - content, reasoning_content, tool_calls = _stream_parse(request, TOKENS_TWO_CALLS_DIFFERENT_FUNCS) - - assert len(tool_calls) == 2 - # tool_calls sorted by index ensures stable order - tc0, tc1 = tool_calls - assert tc0.index == 0 and tc1.index == 1 - assert tc0.function.name == 'get_weather' - assert json.loads(tc0.function.arguments)['location'] == 'Berlin' - assert tc1.function.name == 'get_time' - assert json.loads(tc1.function.arguments)['city'] == 'New York' - assert (content or '').strip() == '' - assert (reasoning_content or '').strip() == '' - - -def test_parser_stream_interleaved_channels(): - from lmdeploy.serve.openai.protocol import ChatCompletionRequest - - _install_openai_harmony_stub() - request = ChatCompletionRequest(model='gpt-oss', messages=[], stream=True) - content, reasoning_content, tool_calls = _stream_parse(request, TOKENS_INTERLEAVED) - - assert json.loads(tool_calls[0].function.arguments)['location'] == 'Paris, France' - assert reasoning_content == 'Thinking about the weather. ' - assert content == 'Fetching the weather now.' 
- - -@pytest.mark.parametrize(('token_chunks', 'expects'), [ - (TOKENS_TWO_CALLS_SAME_FUNC, [TestExpects('get_weather', 'Tokyo'), - TestExpects('get_weather', 'Kyoto')]), -]) -def test_parser_stream_two_calls_same_func(token_chunks: list[list[int]], expects: list[TestExpects]): - from lmdeploy.serve.openai.protocol import ChatCompletionRequest - - _install_openai_harmony_stub() - request = ChatCompletionRequest(model='gpt-oss', messages=[], stream=True) - _, _, tool_calls = _stream_parse(request, token_chunks) - - assert len(tool_calls) == len(expects) - for parsed_call, expected_call in zip(tool_calls, expects): - assert parsed_call.function.name == expected_call.func_name - args = json.loads(parsed_call.function.arguments) - assert args['location'] == expected_call.location - - -def test_open_tool_call_no_args(): - from lmdeploy.serve.openai.protocol import ChatCompletionRequest - - _install_openai_harmony_stub() - request = ChatCompletionRequest(model='gpt-oss', messages=[], stream=True) - content, reasoning_content, tool_calls = _stream_parse(request, [[-1]]) - - assert len(tool_calls) == 1 - assert tool_calls[0].function.name == 'get_weather' - assert (tool_calls[0].function.arguments or '') == '' - assert (content or '') == '' - assert (reasoning_content or '') == '' - - -@pytest.mark.parametrize(('token_chunks', 'expects'), [ - (TOKENS_SINGLE_CALL_TWO_CHUNKS, [TestExpects('get_weather', 'Paris, France')]), - (TOKENS_TWO_CALLS_SAME_FUNC, [TestExpects('get_weather', 'Tokyo'), - TestExpects('get_weather', 'Kyoto')]), -]) -def test_parser_nonstream(token_chunks: list[list[int]], expects: list[TestExpects]): - from lmdeploy.serve.openai.protocol import ChatCompletionRequest - - _install_openai_harmony_stub() - resp = _chat_completion_v1(ChatCompletionRequest(model='gpt-oss', messages=[], stream=False), token_chunks) - - assert len(resp.choices) == 1 - first_message = resp.choices[0].message - assert first_message.content is None - assert 
(first_message.reasoning_content or '') == ''
-    assert len(first_message.tool_calls) == len(expects)
-    for parsed_call, expected_call in zip(first_message.tool_calls, expects):
-        assert parsed_call.function.name == expected_call.func_name
-        args = json.loads(parsed_call.function.arguments)
-        assert args['location'] == expected_call.location
diff --git a/tests/test_lmdeploy/test_qwen3_parser.py b/tests/test_lmdeploy/test_qwen3_parser.py
deleted file mode 100644
index b3d52b47b6..0000000000
--- a/tests/test_lmdeploy/test_qwen3_parser.py
+++ /dev/null
@@ -1,368 +0,0 @@
-import collections
-import json
-import time
-from collections.abc import Generator
-
-import pytest
-import shortuuid
-
-from lmdeploy.serve.openai.api_server import VariableInterface
-from lmdeploy.serve.openai.protocol import (
-    ChatCompletionRequest,
-    ChatCompletionResponse,
-    ChatCompletionResponseChoice,
-    ChatCompletionResponseStreamChoice,
-    ChatCompletionStreamResponse,
-    ChatMessage,
-    DeltaMessage,
-    DeltaToolCall,
-    UsageInfo,
-)
-from lmdeploy.serve.openai.reasoning_parser.qwen_qwq_reasoning_parser import QwenQwQReasoningParser
-from lmdeploy.serve.openai.tool_parser.qwen3_parser import Qwen3ToolParser
-
-TestExpects = collections.namedtuple('TestExpects', 'func_name location')
-
-
-class DummyTokenizer:
-
-    def decode(self, token_ids: list[int]) -> str:
-        return ' '.join(map(str, token_ids))
-
-    def encode(self, text: str) -> list[int]:
-        return [ord(c) for c in text]
-
-
-DELTA_TEXT_SEQUENCE = [
-    '<think>',
-    '\n',
-    '好的',
-    ',',
-    '用户',
-    '问',
-    '的是',
-    '北京',
-    '的',
-    '天气',
-    '怎么样',
-    '。',
-    '我',
-    '需要',
-    '调',
-    '用',
-    'get',
-    '_weather',
-    '这个',
-    '工具',
-    '来',
-    '获取',
-    '信息',
-    '。',
-    '首先',
-    ',',
-    '确认',
-    '用户',
-    '提供的',
-    '地点',
-    '是',
-    '北京',
-    ',',
-    '参数',
-    '正确',
-    '。',
-    '然后',
-    '检查',
-    '工具',
-    '的',
-    '参数',
-    '要求',
-    ',',
-    '只需要',
-    'location',
-    ',',
-    '类型',
-    '是',
-    '字符串',
-    '。',
-    '于是',
-    '构造',
-    '参数',
-    '对象',
-    ',',
-    '调',
-    '用',
-    '函数',
-    ',',
-    '返回',
-    '结果',
-    '。',
-    '确保',
'没有',
-    '遗漏',
-    '必要',
-    '参数',
-    ',',
-    '比如',
-    'location',
-    '是',
-    '必须',
-    '的',
-    ',',
-    '这里',
-    '已经',
-    '提供',
-    ',',
-    '所以',
-    '没问题',
-    '。',
-    '最后',
-    '将',
-    '结果',
-    '以',
-    '自然',
-    '语言',
-    '回复',
-    '用户',
-    '。\n',
-    '</think>',
-    '\n\n',
-    '<tool_call>',
-    '\n',
-    '{"',
-    'name',
-    '":',
-    ' "',
-    'get',
-    '_weather',
-    '",',
-    ' "',
-    'arguments',
-    '":',
-    ' {"',
-    'location',
-    '":',
-    ' "',
-    '北京',
-    '"}}\n',
-    '</tool_call>',
-]
-
-DELTA_TEXT_SEQUENCE_MULTIPLE_CALLS = DELTA_TEXT_SEQUENCE + [
-    '\n\n',
-    '<tool_call>',
-    '\n',
-    '{"',
-    'name',
-    '":',
-    ' "',
-    'get',
-    '_weather',
-    '",',
-    ' "',
-    'arguments',
-    '":',
-    ' {"',
-    'location',
-    '":',
-    ' "',
-    '上海',
-    '"}}\n',
-    '</tool_call>',
-]
-
-EXPECTED_CONTENT = ''
-EXPECTED_REASONING_CONTENT = ''.join((
-    '好的,用户问的是北京的天气怎么样。我需要调用get_weather这个工具来获取信息。',
-    '首先,确认用户提供的地点是北京,参数正确。然后检查工具的参数要求,',
-    '只需要location,类型是字符串。于是构造参数对象,调用函数,返回结果。',
-    '确保没有遗漏必要参数,比如location是必须的,这里已经提供,所以没问题。',
-    '最后将结果以自然语言回复用户。',
-))
-
-
-def _chat_completion_v1(
-        request: ChatCompletionRequest,
-        text_sequence: list[str]) -> ChatCompletionResponse | Generator[ChatCompletionStreamResponse, None, None]:
-    request_id = f'chat-{shortuuid.random()}'
-    created_time = int(time.time())
-    model_name = request.model
-    if request.stream:
-
-        def completion_stream_generator() -> Generator[ChatCompletionStreamResponse, None, None]:
-            previous_text = ''
-            current_text = ''
-            finish_reason = 'stop'
-            has_parser = VariableInterface.tool_parser is not None or VariableInterface.reasoning_parser is not None
-            for text in text_sequence:
-                logprobs, usage = None, None
-                delta_message = DeltaMessage(role='assistant', content=text)
-                if has_parser:
-                    current_text = current_text + text
-                if request.tool_choice != 'none' and VariableInterface.tool_parser is not None:
-                    tool_delta = VariableInterface.tool_parser.extract_tool_calls_streaming(
-                        previous_text=previous_text,
-                        current_text=current_text,
-                        delta_text=delta_message.content,
-                        previous_token_ids=[],
-                        current_token_ids=[],
-                        delta_token_ids=[],
-
request=request) - if tool_delta is not None: - delta_message.tool_calls = tool_delta.tool_calls - delta_message.content = tool_delta.content or '' - if VariableInterface.reasoning_parser is not None: - reasoning_delta = VariableInterface.reasoning_parser.extract_reasoning_content_streaming( - previous_text=previous_text, - current_text=current_text, - delta_text=delta_message.content, - previous_token_ids=[], - current_token_ids=[], - delta_token_ids=[]) - if reasoning_delta is not None: - delta_message.reasoning_content = reasoning_delta.reasoning_content - delta_message.content = reasoning_delta.content or '' - if has_parser: - previous_text = current_text - - choice_data = ChatCompletionResponseStreamChoice(index=0, - delta=delta_message, - finish_reason=finish_reason, - logprobs=logprobs) - response = ChatCompletionStreamResponse( - id=request_id, - created=created_time, - model=model_name, - choices=[choice_data], - usage=usage, - ) - yield response - - return completion_stream_generator() - - # copied and simplified from api_server.py:chat_completions_v1 - text = ''.join(text_sequence) - tool_calls = None - reasoning_content = None - finish_reason = 'stop' - if request.tool_choice != 'none' and VariableInterface.tool_parser is not None: - tool_call_info = VariableInterface.tool_parser.extract_tool_calls(text, request=request) - text, tool_calls = tool_call_info.content, tool_call_info.tool_calls - if isinstance(tool_calls, list) and len(tool_calls): - if finish_reason == 'stop': - finish_reason = 'tool_calls' - - if VariableInterface.reasoning_parser is not None: - reasoning_content, text = VariableInterface.reasoning_parser.extract_reasoning_content(text, request) - - choices = [] - choice_data = ChatCompletionResponseChoice( - index=0, - message=ChatMessage(role='assistant', content=text, tool_calls=tool_calls, reasoning_content=reasoning_content), - finish_reason=finish_reason, - ) - choices.append(choice_data) - - return ChatCompletionResponse( - 
id=request_id, - created=created_time, - model=model_name, - choices=choices, - usage=UsageInfo(), - ) - - -def _stream_parse(request: ChatCompletionRequest, text_sequence: list[str]) -> tuple[str, str, list[DeltaToolCall]]: - # Call parser.extract_tool_calls_streaming with delta_text specified in `DELTA_TEXT_SEQUENCE`. - # `current_text` and `previous_text` init values and update logic - # can be found in lmdeploy/serve/openai/api_server.py:455-523. - content = '' - reasoning_content = '' - tool_calls = {} - - for stream_resp in _chat_completion_v1(request, text_sequence): - delta_message: DeltaMessage = stream_resp.choices[0].delta - if delta_message.content: - content += delta_message.content - if delta_message.reasoning_content: - reasoning_content += delta_message.reasoning_content - if delta_message.tool_calls: - for c in delta_message.tool_calls: - existing_call = tool_calls.get(c.id, None) - if not existing_call: - tool_calls[c.id] = c - continue - # merge with existing - if c.function.name: - existing_call.function.name = c.function.name - if c.function.arguments: - existing_call.function.arguments = existing_call.function.arguments or '' - existing_call.function.arguments += c.function.arguments - return content, reasoning_content, list(sorted(tool_calls.values(), key=lambda x: x.index)) - - -@pytest.mark.parametrize(('text_sequence', 'expects'), [ - (DELTA_TEXT_SEQUENCE, [TestExpects('get_weather', '北京')]), - (DELTA_TEXT_SEQUENCE_MULTIPLE_CALLS, [TestExpects('get_weather', '北京'), - TestExpects('get_weather', '上海')]), -]) -def test_parser_stream(text_sequence: list[str], expects: list[TestExpects]): - tokenizer = DummyTokenizer() - VariableInterface.tool_parser = Qwen3ToolParser(tokenizer=tokenizer) - VariableInterface.reasoning_parser = QwenQwQReasoningParser(tokenizer=tokenizer) - request = ChatCompletionRequest(model='qwen', messages=[], stream=True) - content, reasoning_content, tool_calls = _stream_parse(request, text_sequence) - assert 
len(tool_calls) == len(expects) - for parsed_call, expected_call in zip(tool_calls, expects): - assert parsed_call.function.name == expected_call.func_name - args = json.loads(parsed_call.function.arguments) - assert args['location'] == expected_call.location - assert content.strip() == EXPECTED_CONTENT - assert reasoning_content.strip() == EXPECTED_REASONING_CONTENT - - -@pytest.mark.parametrize(('text_sequence', 'expects'), [ - (DELTA_TEXT_SEQUENCE, [TestExpects('get_weather', '北京')]), - (DELTA_TEXT_SEQUENCE_MULTIPLE_CALLS, [TestExpects('get_weather', '北京'), - TestExpects('get_weather', '上海')]), -]) -def test_parser_nonstream(text_sequence: list[str], expects: list[TestExpects]): - tokenizer = DummyTokenizer() - VariableInterface.tool_parser = Qwen3ToolParser(tokenizer=tokenizer) - VariableInterface.reasoning_parser = QwenQwQReasoningParser(tokenizer=tokenizer) - resp: ChatCompletionResponse = _chat_completion_v1(ChatCompletionRequest(model='qwen', messages=[], stream=False), - text_sequence) - - assert len(resp.choices) == 1 - first_message = resp.choices[0].message - assert first_message.content is None - assert first_message.reasoning_content == EXPECTED_REASONING_CONTENT - assert len(first_message.tool_calls) == len(expects) - for parsed_call, expected_call in zip(first_message.tool_calls, expects): - assert parsed_call.function.name == expected_call.func_name - args = json.loads(parsed_call.function.arguments) - assert args['location'] == expected_call.location - - -def test_no_think_nonstream(): - text_sequence = [ - '你好', - '呀', - '!', - '✨', - '', - ' 很', - '高兴', - '见到', - '你', - '!', - ] - tokenizer = DummyTokenizer() - VariableInterface.tool_parser = Qwen3ToolParser(tokenizer=tokenizer) - VariableInterface.reasoning_parser = QwenQwQReasoningParser(tokenizer=tokenizer) - resp: ChatCompletionResponse = _chat_completion_v1(ChatCompletionRequest(model='qwen', messages=[], stream=False), - text_sequence) - - assert len(resp.choices) == 1 - first_message = 
resp.choices[0].message
-    assert first_message.content == '你好呀!✨ 很高兴见到你!'
-    assert first_message.reasoning_content is None
diff --git a/tests/test_lmdeploy/test_qwen3coder_parser.py b/tests/test_lmdeploy/test_qwen3coder_parser.py
deleted file mode 100644
index 291459690c..0000000000
--- a/tests/test_lmdeploy/test_qwen3coder_parser.py
+++ /dev/null
@@ -1,412 +0,0 @@
-import collections
-import json
-import time
-from collections.abc import Generator
-
-import pytest
-import shortuuid
-
-from lmdeploy.model import MODELS
-from lmdeploy.serve.openai.api_server import VariableInterface
-from lmdeploy.serve.openai.protocol import (
-    ChatCompletionRequest,
-    ChatCompletionResponse,
-    ChatCompletionResponseChoice,
-    ChatCompletionResponseStreamChoice,
-    ChatCompletionStreamResponse,
-    ChatMessage,
-    DeltaMessage,
-    DeltaToolCall,
-    UsageInfo,
-)
-from lmdeploy.serve.openai.tool_parser.qwen3coder_parser import Qwen3CoderToolParser
-
-TestExpects = collections.namedtuple('TestExpects', 'func_name kwargs')
-
-
-class DummyTokenizer:
-
-    def decode(self, token_ids: list[int]) -> str:
-        return ' '.join(map(str, token_ids))
-
-    def encode(self, text: str) -> list[int]:
-        return [ord(c) for c in text]
-
-
-DELTA_TEXT_SEQUENCE = [
-    '好的,我现在帮你调用工具。\n',
-    '',
-    '\n',
-    '\n',
-    '',
-    '北京\n',
-    'celsius\n',
-    '\n',
-    '',
-]
-
-DELTA_TEXT_SEQUENCE_MULTIPLE_CALLS = DELTA_TEXT_SEQUENCE + [
-    '\n\n',
-    '',
-    '\n\n',
-    '上海\n',
-    '\n',
-    '',
-]
-
-EXPECTED_CONTENT = '好的,我现在帮你调用工具。'
-
-
-def _chat_completion_v1(
-        request: ChatCompletionRequest,
-        text_sequence: list[str]) -> ChatCompletionResponse | Generator[ChatCompletionStreamResponse, None, None]:
-    request_id = f'chat-{shortuuid.random()}'
-    created_time = int(time.time())
-    model_name = request.model
-    if request.stream:
-
-        def completion_stream_generator() -> Generator[ChatCompletionStreamResponse, None, None]:
-            previous_text = ''
-            current_text = ''
-            finish_reason = 'stop'
-            has_parser = (VariableInterface.tool_parser is not None or VariableInterface.reasoning_parser is not None)
-            for text in text_sequence:
-                logprobs, usage = None, None
-                delta_message = DeltaMessage(role='assistant', content=text)
-                if has_parser:
-                    current_text = current_text + text
-                has_tool = VariableInterface.tool_parser is not None
-                if request.tool_choice != 'none' and has_tool:
-                    tool_delta = VariableInterface.tool_parser.extract_tool_calls_streaming(
-                        previous_text=previous_text,
-                        current_text=current_text,
-                        delta_text=delta_message.content,
-                        previous_token_ids=[],
-                        current_token_ids=[],
-                        delta_token_ids=[],
-                        request=request)
-                    if tool_delta is not None:
-                        delta_message.tool_calls = tool_delta.tool_calls
-                        delta_message.content = tool_delta.content or ''
-                if VariableInterface.reasoning_parser is not None:
-                    parser = VariableInterface.reasoning_parser
-                    reasoning_delta = parser.extract_reasoning_content_streaming(previous_text=previous_text,
-                                                                                 current_text=current_text,
-                                                                                 delta_text=delta_message.content,
-                                                                                 previous_token_ids=[],
-                                                                                 current_token_ids=[],
-                                                                                 delta_token_ids=[])
-                    if reasoning_delta is not None:
-                        delta_message.reasoning_content = (reasoning_delta.reasoning_content)
-                        delta_message.content = reasoning_delta.content or ''
-                if has_parser:
-                    previous_text = current_text
-
-                choice_data = ChatCompletionResponseStreamChoice(index=0,
-                                                                 delta=delta_message,
-                                                                 finish_reason=finish_reason,
-                                                                 logprobs=logprobs)
-                response = ChatCompletionStreamResponse(
-                    id=request_id,
-                    created=created_time,
-                    model=model_name,
-                    choices=[choice_data],
-                    usage=usage,
-                )
-                yield response
-
-        return completion_stream_generator()
-
-    text = ''.join(text_sequence)
-    tool_calls = None
-    reasoning_content = None
-    finish_reason = 'stop'
-    has_tool = VariableInterface.tool_parser is not None
-    if request.tool_choice != 'none' and has_tool:
-        tool_call_info = VariableInterface.tool_parser.extract_tool_calls(text, request=request)
-        text, tool_calls = tool_call_info.content, tool_call_info.tool_calls
-        if isinstance(tool_calls, list) and len(tool_calls):
-            if finish_reason == 'stop':
-                finish_reason = 'tool_calls'
-
-    if VariableInterface.reasoning_parser is not None:
-        parser = VariableInterface.reasoning_parser
-        reasoning_content, text = parser.extract_reasoning_content(text, request)
-
-    choices = []
-    choice_data = ChatCompletionResponseChoice(
-        index=0,
-        message=ChatMessage(role='assistant', content=text, tool_calls=tool_calls, reasoning_content=reasoning_content),
-        finish_reason=finish_reason,
-    )
-    choices.append(choice_data)
-
-    return ChatCompletionResponse(
-        id=request_id,
-        created=created_time,
-        model=model_name,
-        choices=choices,
-        usage=UsageInfo(),
-    )
-
-
-def _stream_parse(request: ChatCompletionRequest, text_sequence: list[str]) -> tuple[str, str, list[DeltaToolCall]]:
-    content = ''
-    reasoning_content = ''
-    tool_calls = {}
-
-    for stream_resp in _chat_completion_v1(request, text_sequence):
-        delta_message: DeltaMessage = stream_resp.choices[0].delta
-        if delta_message.content:
-            content += delta_message.content
-        if delta_message.reasoning_content:
-            reasoning_content += delta_message.reasoning_content
-        if delta_message.tool_calls:
-            for c in delta_message.tool_calls:
-                existing_call = tool_calls.get(c.id, None)
-                if not existing_call:
-                    tool_calls[c.id] = c
-                    continue
-                # merge with existing
-                if c.function.name:
-                    existing_call.function.name = c.function.name
-                if c.function.arguments:
-                    existing_call.function.arguments = (existing_call.function.arguments or '')
-                    existing_call.function.arguments += c.function.arguments
-    return content, reasoning_content, list(sorted(tool_calls.values(), key=lambda x: x.index))
-
-
-@pytest.mark.parametrize(('text_sequence', 'expects'), [
-    (DELTA_TEXT_SEQUENCE, [TestExpects('get_weather', {
-        'location': '北京',
-        'unit': 'celsius'
-    })]),
-    (DELTA_TEXT_SEQUENCE_MULTIPLE_CALLS, [
-        TestExpects('get_weather', {
-            'location': '北京',
-            'unit': 'celsius'
-        }),
-        TestExpects('get_weather', {'location': '上海'})
-    ]),
-])
-def test_parser_stream(text_sequence: list[str], expects: list[TestExpects]):
-    tokenizer = DummyTokenizer()
-    VariableInterface.tool_parser = Qwen3CoderToolParser(tokenizer=tokenizer)
-    VariableInterface.reasoning_parser = None
-    request = ChatCompletionRequest(model='qwen3coder', messages=[], stream=True)
-    content, reasoning_content, tool_calls = _stream_parse(request, text_sequence)
-    assert len(tool_calls) == len(expects)
-    for parsed_call, expected_call in zip(tool_calls, expects):
-        assert parsed_call.function.name == expected_call.func_name
-        args = json.loads(parsed_call.function.arguments)
-        assert args == expected_call.kwargs
-    assert content.strip() == EXPECTED_CONTENT
-
-
-@pytest.mark.parametrize(('text_sequence', 'expects'), [
-    (DELTA_TEXT_SEQUENCE, [TestExpects('get_weather', {
-        'location': '北京',
-        'unit': 'celsius'
-    })]),
-    (DELTA_TEXT_SEQUENCE_MULTIPLE_CALLS, [
-        TestExpects('get_weather', {
-            'location': '北京',
-            'unit': 'celsius'
-        }),
-        TestExpects('get_weather', {'location': '上海'})
-    ]),
-])
-def test_parser_nonstream(text_sequence: list[str], expects: list[TestExpects]):
-    tokenizer = DummyTokenizer()
-    VariableInterface.tool_parser = Qwen3CoderToolParser(tokenizer=tokenizer)
-    VariableInterface.reasoning_parser = None
-    resp: ChatCompletionResponse = _chat_completion_v1(
-        ChatCompletionRequest(model='qwen3coder', messages=[], stream=False), text_sequence)
-
-    assert len(resp.choices) == 1
-    first_message = resp.choices[0].message
-    assert first_message.content.strip() == EXPECTED_CONTENT
-    assert first_message.reasoning_content is None
-    assert len(first_message.tool_calls) == len(expects)
-    for parsed_call, expected_call in zip(first_message.tool_calls, expects):
-        assert parsed_call.function.name == expected_call.func_name
-        args = json.loads(parsed_call.function.arguments)
-        assert args == expected_call.kwargs
-
-
-def test_no_think_nonstream():
-    text_sequence = [
-        '你好',
-        '呀',
-        '!',
-        '✨',
-        '',
-        ' 很',
-        '高兴',
-        '见到',
-        '你',
-        '!',
-    ]
-    tokenizer = DummyTokenizer()
-    VariableInterface.tool_parser = Qwen3CoderToolParser(tokenizer=tokenizer)
-    VariableInterface.reasoning_parser = None
-    resp: ChatCompletionResponse = _chat_completion_v1(
-        ChatCompletionRequest(model='qwen3coder', messages=[], stream=False), text_sequence)
-
-    assert len(resp.choices) == 1
-    first_message = resp.choices[0].message
-    assert first_message.content == '你好呀!✨ 很高兴见到你!'
-    assert first_message.reasoning_content is None
-
-
-def test_adjust_request_parses_assistant_tool_call_object_arguments():
-    parser = Qwen3CoderToolParser(tokenizer=DummyTokenizer())
-    request = ChatCompletionRequest(model='qwen3coder',
-                                    messages=[{
-                                        'role': 'user',
-                                        'content': 'hello'
-                                    }, {
-                                        'role': 'assistant',
-                                        'content': '',
-                                        'tool_calls': [{
-                                            'id': 'call_1',
-                                            'type': 'function',
-                                            'function': {
-                                                'name': 'get_weather',
-                                                'arguments': '{"city": "Paris", "units": "metric"}'
-                                            }
-                                        }]
-                                    }])
-
-    adjusted_request = parser.adjust_request(request)
-
-    assert adjusted_request is not request
-    assert adjusted_request.messages is not request.messages
-    assert adjusted_request.messages[1] is not request.messages[1]
-    assert adjusted_request.messages[1]['tool_calls'][0] is not request.messages[1]['tool_calls'][0]
-    assert adjusted_request.messages[1]['tool_calls'][0]['function']['arguments'] == {
-        'city': 'Paris',
-        'units': 'metric'
-    }
-    assert request.messages[1]['tool_calls'][0]['function']['arguments'] == '{"city": "Paris", "units": "metric"}'
-
-
-@pytest.mark.parametrize('arguments', ['[1, 2, 3]', '1', '{not valid json}'])
-def test_adjust_request_leaves_non_mapping_arguments_unchanged(arguments):
-    parser = Qwen3CoderToolParser(tokenizer=DummyTokenizer())
-    request = ChatCompletionRequest(model='qwen3coder',
-                                    messages=[{
-                                        'role': 'assistant',
-                                        'content': '',
-                                        'tool_calls': [{
-                                            'id': 'call_1',
-                                            'type': 'function',
-                                            'function': {
-                                                'name': 'fn',
-                                                'arguments': arguments
-                                            }
-                                        }]
-                                    }])
-
-    adjusted_request = parser.adjust_request(request)
-
-    assert adjusted_request is request
-
-
-def test_adjust_request_noops_for_string_messages():
-    parser = Qwen3CoderToolParser(tokenizer=DummyTokenizer())
-    request = ChatCompletionRequest(model='qwen3coder', messages='hello')
-
-    adjusted_request = parser.adjust_request(request)
-
-    assert adjusted_request is request
-
-
-def test_adjust_request_noops_without_assistant_tool_calls():
-    parser = Qwen3CoderToolParser(tokenizer=DummyTokenizer())
-    request = ChatCompletionRequest(model='qwen3coder',
-                                    messages=[{
-                                        'role': 'user',
-                                        'content': 'hello'
-                                    }, {
-                                        'role': 'assistant',
-                                        'content': 'plain text response'
-                                    }, {
-                                        'role': 'tool',
-                                        'content': '',
-                                        'tool_calls': [{
-                                            'id': 'call_1',
-                                            'type': 'function',
-                                            'function': {
-                                                'name': 'fn',
-                                                'arguments': '{"x": 1}'
-                                            }
-                                        }]
-                                    }])
-
-    adjusted_request = parser.adjust_request(request)
-
-    assert adjusted_request is request
-
-
-def test_adjust_request_noops_for_dict_arguments():
-    parser = Qwen3CoderToolParser(tokenizer=DummyTokenizer())
-    request = ChatCompletionRequest(model='qwen3coder',
-                                    messages=[{
-                                        'role': 'assistant',
-                                        'content': '',
-                                        'tool_calls': [{
-                                            'id': 'call_1',
-                                            'type': 'function',
-                                            'function': {
-                                                'name': 'fn',
-                                                'arguments': {
-                                                    'x': 1
-                                                }
-                                            }
-                                        }]
-                                    }])
-
-    adjusted_request = parser.adjust_request(request)
-
-    assert adjusted_request is request
-
-
-@pytest.mark.parametrize('model_path', ['Qwen/Qwen3.5-35B-A3B'])
-def test_adjust_request_renders_qwen_template_from_string_payload(model_path):
-    chat_template = MODELS.get('hf')(model_path)
-    parser = Qwen3CoderToolParser(tokenizer=DummyTokenizer())
-    request = ChatCompletionRequest(model='qwen3coder',
-                                    messages=[{
-                                        'role': 'user',
-                                        'content': 'What is the weather in Paris?'
-                                    }, {
-                                        'role': 'assistant',
-                                        'content': '',
-                                        'tool_calls': [{
-                                            'id': 'call_1',
-                                            'type': 'function',
-                                            'function': {
-                                                'name': 'get_weather',
-                                                'arguments': '{"city":"Paris","units":"metric"}'
-                                            }
-                                        }]
-                                    }])
-
-    adjusted_request = parser.adjust_request(request)
-    prompt = chat_template.messages2prompt(adjusted_request.messages)
-
-    assert adjusted_request is not request
-    assert adjusted_request.messages[1]['tool_calls'][0]['function']['arguments'] == {
-        'city': 'Paris',
-        'units': 'metric'
-    }
-    assert request.messages[1]['tool_calls'][0]['function']['arguments'] == '{"city":"Paris","units":"metric"}'
-    assert '' in prompt
-    assert '\nParis\n' in prompt
-    assert '\nmetric\n' in prompt