diff --git a/lmdeploy/serve/openai/api_server.py b/lmdeploy/serve/openai/api_server.py index 5e84ee4221..1c10a7c95c 100644 --- a/lmdeploy/serve/openai/api_server.py +++ b/lmdeploy/serve/openai/api_server.py @@ -192,8 +192,8 @@ def _create_completion_logprobs(tokenizer: Tokenizer, Args: tokenizer (Tokenizer): tokenizer. - token_ids (list[int]): output token ids. - logprobs (list[dict[int, float]]): the top logprobs for each output + token_ids (List[int]): output token ids. + logprobs (List[Dict[int, float]]): the top logprobs for each output position. skip_special_tokens (bool): Whether or not to remove special tokens in the decoding. Default to be True. @@ -246,8 +246,8 @@ def _create_chat_completion_logprobs(tokenizer: Tokenizer, Args: tokenizer (Tokenizer): tokenizer. - token_ids (list[int]): output token ids. - logprobs (list[dict[int, float]]): the top logprobs for each output + token_ids (List[int]): output token ids. + logprobs (List[Dict[int, float]]): the top logprobs for each output position. Returns: ChoiceLogprobs: logprob result. @@ -342,7 +342,7 @@ async def chat_completions_v1(request: ChatCompletionRequest, raw_request: Reque probable tokens with probabilities that add up to top_p or higher are kept for generation. - **n** (int): How many chat completion choices to generate for each input - message. **Only support one here**. + message. Default to 1. - **stream**: whether to stream the results or not. Default to false. - **stream_options**: Options for streaming response. Only set this when you set stream: true. @@ -351,7 +351,7 @@ async def chat_completions_v1(request: ChatCompletionRequest, raw_request: Reque Deprecated: Use max_completion_tokens instead. - **repetition_penalty** (float): The parameter for repetition penalty. 1.0 means no penalty - - **stop** (str | list[str] | None): To stop generating further + - **stop** (str | List[str] | None): To stop generating further tokens. Only accept stop words that's encoded to one token idex. - **response_format** (dict | None): To generate response according to given schema. 
Examples: @@ -411,7 +411,12 @@ async def chat_completions_v1(request: ChatCompletionRequest, raw_request: Reque return error_check_ret if VariableInterface.tool_parser is not None: request = VariableInterface.tool_parser.adjust_request(request) - session = VariableInterface.get_session(request.session_id) + _n = request.n or 1 + if _n == 1: + sessions = [VariableInterface.get_session(request.session_id)] + else: + sessions = [VariableInterface.get_session(-1) for _ in range(_n)] + session = sessions[0] json_request = await raw_request.json() migration_request = json_request.pop('migration_request', None) @@ -426,9 +431,7 @@ async def chat_completions_v1(request: ChatCompletionRequest, raw_request: Reque adapter_name = model_name # got a adapter name request_id = str(session.session_id) created_time = int(time.time()) - gpt_oss_parser = None - if VariableInterface.async_engine.arch == 'GptOssForCausalLM': - gpt_oss_parser = GptOssChatParser() + is_gpt_oss = VariableInterface.async_engine.arch == 'GptOssForCausalLM' if isinstance(request.stop, str): request.stop = [request.stop] @@ -479,7 +482,7 @@ async def chat_completions_v1(request: ChatCompletionRequest, raw_request: Reque gen_config.skip_special_tokens = False # internlm2 only uses contents inside function regardless of 'type' if not isinstance(request.tool_choice, str): - if gpt_oss_parser: + if is_gpt_oss: tools = [ item.model_dump() for item in request.tools if item.function.name == request.tool_choice.function.name @@ -490,7 +493,7 @@ async def chat_completions_v1(request: ChatCompletionRequest, raw_request: Reque if item.function.name == request.tool_choice.function.name ] else: - if gpt_oss_parser: + if is_gpt_oss: tools = [item.model_dump() for item in request.tools] else: tools = [item.function.model_dump() for item in request.tools] @@ -505,20 +508,28 @@ async def chat_completions_v1(request: ChatCompletionRequest, raw_request: Reque else: logger.warning('`enable_thinking` in `chat_template_kwargs` will override the value in request.') enable_thinking = chat_template_kwargs.get('enable_thinking', None) - result_generator = VariableInterface.async_engine.generate( - request.messages, - session, - gen_config=gen_config, - tools=tools, - reasoning_effort=request.reasoning_effort, - stream_response=True, # always use stream to enable batching - sequence_start=True, - sequence_end=True, - do_preprocess=do_preprocess, - adapter_name=adapter_name, - chat_template_kwargs=chat_template_kwargs or None, - media_io_kwargs=request.media_io_kwargs, - mm_processor_kwargs=request.mm_processor_kwargs) + generators = [] + for _i, _sess in enumerate(sessions): + if _i > 0 and random_seed is not None: + _cfg = copy.copy(gen_config) + _cfg.random_seed = random_seed + _i + else: + _cfg = gen_config + generators.append( + VariableInterface.async_engine.generate( + request.messages, + _sess, + gen_config=_cfg, + tools=tools, + reasoning_effort=request.reasoning_effort, + stream_response=True, # always use stream to enable batching + sequence_start=True, + sequence_end=True, + do_preprocess=do_preprocess, + adapter_name=adapter_name, + chat_template_kwargs=chat_template_kwargs or None, + media_io_kwargs=request.media_io_kwargs, + mm_processor_kwargs=request.mm_processor_kwargs)) def create_stream_response_json(index: int, delta_message: DeltaMessage, @@ -536,86 +547,89 @@ def create_stream_response_json(index: int, choices=[choice_data], usage=usage, ) - response_json = response.model_dump_json() + response_json = response.model_dump() return 
response_json async def completion_stream_generator() -> AsyncGenerator[str, None]: - previous_text = '' - current_text = '' - previous_token_ids = [] - current_token_ids = [] - delta_token_ids = [] has_parser = VariableInterface.tool_parser is not None or VariableInterface.reasoning_parser is not None - streaming_tools = False - async for res in result_generator: - logprobs, usage = None, None - if gen_logprobs and res.logprobs: - logprobs = _create_chat_completion_logprobs(VariableInterface.async_engine.tokenizer, res.token_ids, - res.logprobs) - # Only stream chunk `usage` in the final chunk according to OpenAI API spec - if (res.finish_reason and request.stream_options and request.stream_options.include_usage): - total_tokens = sum([res.input_token_len, res.generate_token_len]) - usage = UsageInfo( - prompt_tokens=res.input_token_len, - completion_tokens=res.generate_token_len, - total_tokens=total_tokens, - ) - - delta_token_ids = res.token_ids if res.token_ids is not None else [] - if gpt_oss_parser: - delta_message = gpt_oss_parser.parse_streaming(res.token_ids) - if res.finish_reason == 'stop' and len(delta_message.tool_calls) > 0: - res.finish_reason = 'tool_calls' - else: - delta_message = DeltaMessage(role='assistant', content=res.response) - if has_parser: - current_text = current_text + res.response - current_token_ids = current_token_ids + delta_token_ids - if request.tool_choice != 'none' and VariableInterface.tool_parser is not None: - if res.finish_reason == 'stop' and streaming_tools is True: + for _gen_idx, _gen in enumerate(generators): + previous_text = '' + current_text = '' + previous_token_ids = [] + current_token_ids = [] + delta_token_ids = [] + streaming_tools = False + # each generator needs its own stateful streaming parser instance + _gpt_oss_parser = GptOssChatParser() if is_gpt_oss else None + async for res in _gen: + logprobs, usage = None, None + if gen_logprobs and res.logprobs: + logprobs = _create_chat_completion_logprobs(VariableInterface.async_engine.tokenizer, res.token_ids, + res.logprobs) + # Only stream chunk `usage` in the final chunk according to OpenAI API spec + if (res.finish_reason and request.stream_options and request.stream_options.include_usage): + total_tokens = sum([res.input_token_len, res.generate_token_len]) + usage = UsageInfo( + prompt_tokens=res.input_token_len, + completion_tokens=res.generate_token_len, + total_tokens=total_tokens, + ) + + delta_token_ids = res.token_ids if res.token_ids is not None else [] + if _gpt_oss_parser: + delta_message = _gpt_oss_parser.parse_streaming(res.token_ids) + if res.finish_reason == 'stop' and len(delta_message.tool_calls) > 0: res.finish_reason = 'tool_calls' - tool_delta = VariableInterface.tool_parser.extract_tool_calls_streaming( - previous_text=previous_text, - current_text=current_text, - delta_text=delta_message.content, - previous_token_ids=previous_token_ids, - current_token_ids=current_token_ids, - delta_token_ids=delta_token_ids, - request=request) - if tool_delta is not None: - delta_message.tool_calls = tool_delta.tool_calls - delta_message.content = tool_delta.content - if isinstance(tool_delta.tool_calls, list) and len(tool_delta.tool_calls): - streaming_tools = True - elif (request.tool_choice != 'none' and request.tools is not None - and VariableInterface.tool_parser is None): - logger.error('Please launch the api_server with --tool-call-parser if you want to use tool.') - if VariableInterface.reasoning_parser is not None and enable_thinking is not False: - reasoning_delta = 
VariableInterface.reasoning_parser.extract_reasoning_content_streaming( - previous_text=previous_text, - current_text=current_text, - delta_text=delta_message.content or '', - previous_token_ids=previous_token_ids, - current_token_ids=current_token_ids, - delta_token_ids=delta_token_ids) - if reasoning_delta is not None: - delta_message.reasoning_content = reasoning_delta.reasoning_content - delta_message.content = reasoning_delta.content - if has_parser: - previous_text = current_text - previous_token_ids = current_token_ids - if request.return_token_ids: - delta_message.gen_tokens = delta_token_ids - response_json = create_stream_response_json(index=0, - delta_message=delta_message, - finish_reason=res.finish_reason, - logprobs=logprobs, - usage=usage) - if res.cache_block_ids is not None: - response_json['cache_block_ids'] = res.cache_block_ids - response_json['remote_token_ids'] = res.token_ids - yield f'data: {response_json}\n\n' + else: + delta_message = DeltaMessage(role='assistant', content=res.response) + if has_parser: + current_text = current_text + res.response + current_token_ids = current_token_ids + delta_token_ids + if request.tool_choice != 'none' and VariableInterface.tool_parser is not None: + if res.finish_reason == 'stop' and streaming_tools is True: + res.finish_reason = 'tool_calls' + tool_delta = VariableInterface.tool_parser.extract_tool_calls_streaming( + previous_text=previous_text, + current_text=current_text, + delta_text=delta_message.content, + previous_token_ids=previous_token_ids, + current_token_ids=current_token_ids, + delta_token_ids=delta_token_ids, + request=request) + if tool_delta is not None: + delta_message.tool_calls = tool_delta.tool_calls + delta_message.content = tool_delta.content + if isinstance(tool_delta.tool_calls, list) and len(tool_delta.tool_calls): + streaming_tools = True + elif (request.tool_choice != 'none' and request.tools is not None + and VariableInterface.tool_parser is None): + logger.error('Please launch the api_server with --tool-call-parser if you want to use tool.') + if VariableInterface.reasoning_parser is not None and enable_thinking is not False: + reasoning_delta = VariableInterface.reasoning_parser.extract_reasoning_content_streaming( + previous_text=previous_text, + current_text=current_text, + delta_text=delta_message.content or '', + previous_token_ids=previous_token_ids, + current_token_ids=current_token_ids, + delta_token_ids=delta_token_ids) + if reasoning_delta is not None: + delta_message.reasoning_content = reasoning_delta.reasoning_content + delta_message.content = reasoning_delta.content + if has_parser: + previous_text = current_text + previous_token_ids = current_token_ids + if request.return_token_ids: + delta_message.gen_tokens = delta_token_ids + response_json = create_stream_response_json(index=_gen_idx, + delta_message=delta_message, + finish_reason=res.finish_reason, + logprobs=logprobs, + usage=usage) + if res.cache_block_ids is not None: + response_json['cache_block_ids'] = res.cache_block_ids + response_json['remote_token_ids'] = res.token_ids + yield f'data: {json.dumps(response_json)}\n\n' yield 'data: [DONE]\n\n' # Streaming response @@ -623,81 +637,95 @@ async def completion_stream_generator() -> AsyncGenerator[str, None]: return StreamingResponse(completion_stream_generator(), media_type='text/event-stream') # Non-streaming response - final_logprobs = [] - final_token_ids = [] - final_res = None - text = '' - cache_block_ids = [] - remote_token_ids = [] - async for res in result_generator: 
- if await raw_request.is_disconnected(): - # Abort the request if the client disconnects. - await session.async_abort() - return create_error_response(HTTPStatus.BAD_REQUEST, 'Client disconnected') - final_res = res - text += res.response - if res.token_ids: - final_token_ids.extend(res.token_ids) - if res.logprobs: - final_logprobs.extend(res.logprobs) - cache_block_ids.append(res.cache_block_ids) - remote_token_ids.append(res.token_ids) - - if gpt_oss_parser: - message = gpt_oss_parser.parse_full(final_token_ids) - if final_res.finish_reason == 'stop' and len(message.tool_calls) > 0: - final_res.finish_reason = 'tool_calls' - else: - tool_calls = None - reasoning_content = None - if request.tool_choice != 'none' and VariableInterface.tool_parser is not None: - try: - tool_call_info = VariableInterface.tool_parser.extract_tool_calls(text, request=request) - text, tool_calls = tool_call_info.content, tool_call_info.tool_calls - if isinstance(tool_calls, list) and len(tool_calls): - if final_res.finish_reason == 'stop': - final_res.finish_reason = 'tool_calls' - - except Exception as e: - logger.error(f'Failed to parse {text}. Exception: {e}.') - return create_error_response(HTTPStatus.BAD_REQUEST, 'Failed to parse fc related info to json format!') - elif request.tool_choice != 'none' and request.tools is not None and VariableInterface.tool_parser is None: - logger.error('Please launch the api_server with --tool-call-parser if you want to use tool.') - - if VariableInterface.reasoning_parser is not None and enable_thinking is not False: - reasoning_content, text = VariableInterface.reasoning_parser.extract_reasoning_content(text, request) - - message = ChatMessage(role='assistant', - content=text, - tool_calls=tool_calls, - reasoning_content=reasoning_content) - - logprobs = None - if gen_logprobs and len(final_logprobs): - logprobs = _create_chat_completion_logprobs(VariableInterface.async_engine.tokenizer, final_token_ids, - final_logprobs) - - assert final_res is not None - choices = [] - if request.return_token_ids: - message.gen_tokens = final_token_ids - choice_data = ChatCompletionResponseChoice( - index=0, - message=message, - logprobs=logprobs, - finish_reason=final_res.finish_reason, - ) - choices.append(choice_data) + choices = [None] * _n + _prompt_tokens = 0 + _completion_tokens = 0 + _cache_block_ids_list = [None] * _n + _remote_token_ids_list = [None] * _n + + async def _collect_chat_response(_i, _gen, _sess): + nonlocal _prompt_tokens, _completion_tokens + final_logprobs_i = [] + final_token_ids_i = [] + final_res_i = None + text_i = '' + cache_block_ids_i = [] + remote_token_ids_i = [] + async for res in _gen: + if await raw_request.is_disconnected(): + await _sess.async_abort() + return False + final_res_i = res + text_i += res.response + if res.token_ids: + final_token_ids_i.extend(res.token_ids) + if res.logprobs: + final_logprobs_i.extend(res.logprobs) + cache_block_ids_i.append(res.cache_block_ids) + remote_token_ids_i.append(res.token_ids) + + if is_gpt_oss: + _parser_i = GptOssChatParser() + message_i = _parser_i.parse_full(final_token_ids_i) + if final_res_i.finish_reason == 'stop' and len(message_i.tool_calls) > 0: + final_res_i.finish_reason = 'tool_calls' + else: + tool_calls_i = None + reasoning_content_i = None + if request.tool_choice != 'none' and VariableInterface.tool_parser is not None: + try: + tool_call_info = VariableInterface.tool_parser.extract_tool_calls(text_i, request=request) + text_i, tool_calls_i = tool_call_info.content, 
tool_call_info.tool_calls + if isinstance(tool_calls_i, list) and len(tool_calls_i): + if final_res_i.finish_reason == 'stop': + final_res_i.finish_reason = 'tool_calls' + except Exception as e: + logger.error(f'Failed to parse {text_i}. Exception: {e}.') + raise + elif request.tool_choice != 'none' and request.tools is not None and VariableInterface.tool_parser is None: + logger.error('Please launch the api_server with --tool-call-parser if you want to use tool.') + + if VariableInterface.reasoning_parser is not None and enable_thinking is not False: + reasoning_content_i, text_i = VariableInterface.reasoning_parser.extract_reasoning_content( + text_i, request) + + message_i = ChatMessage(role='assistant', + content=text_i, + tool_calls=tool_calls_i, + reasoning_content=reasoning_content_i) + + logprobs_i = None + if gen_logprobs and len(final_logprobs_i): + logprobs_i = _create_chat_completion_logprobs(VariableInterface.async_engine.tokenizer, final_token_ids_i, + final_logprobs_i) + + assert final_res_i is not None + if request.return_token_ids: + message_i.gen_tokens = final_token_ids_i + choices[_i] = ChatCompletionResponseChoice( + index=_i, + message=message_i, + logprobs=logprobs_i, + finish_reason=final_res_i.finish_reason, + ) + _cache_block_ids_list[_i] = cache_block_ids_i + _remote_token_ids_list[_i] = remote_token_ids_i + # prompt_tokens is identical across all outputs (same input); completion_tokens is summed + _prompt_tokens = final_res_i.input_token_len + _completion_tokens += final_res_i.generate_token_len + return True - if with_cache: - cache_block_ids = cache_block_ids[0] - remote_token_ids = [remote_token_ids[0][-1]] + try: + results = await asyncio.gather(*[_collect_chat_response(_i, generators[_i], sessions[_i]) for _i in range(_n)]) + except Exception as e: + return create_error_response(HTTPStatus.INTERNAL_SERVER_ERROR, str(e)) + if not all(results): + return create_error_response(HTTPStatus.BAD_REQUEST, 'Client disconnected') - total_tokens = sum([final_res.input_token_len, final_res.generate_token_len]) usage = UsageInfo( - prompt_tokens=final_res.input_token_len, - completion_tokens=final_res.generate_token_len, - total_tokens=total_tokens, + prompt_tokens=_prompt_tokens, + completion_tokens=_completion_tokens, + total_tokens=_prompt_tokens + _completion_tokens, ) response = ChatCompletionResponse( id=request_id, @@ -707,9 +735,11 @@ async def completion_stream_generator() -> AsyncGenerator[str, None]: usage=usage, ).model_dump() - if with_cache: - response['cache_block_ids'] = cache_block_ids - response['remote_token_ids'] = remote_token_ids + if with_cache and _n == 1: + _cb = _cache_block_ids_list[0] + _rt = _remote_token_ids_list[0] + response['cache_block_ids'] = _cb[0] if _cb else None + response['remote_token_ids'] = [_rt[0][-1]] if _rt and _rt[0] else None return response @@ -734,7 +764,7 @@ async def completions_v1(request: CompletionRequest, raw_request: Request = None probable tokens with probabilities that add up to top_p or higher are kept for generation. - **n** (int): How many chat completion choices to generate for each input - message. **Only support one here**. + message. Default to 1. - **stream**: whether to stream the results or not. Default to false. - **stream_options**: Options for streaming response. Only set this when you set stream: true. 
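With `n` now honored by both `/v1/chat/completions` and `/v1/completions`, a single request can return several sampled choices. A minimal client-side sketch using the `openai` Python package (the base URL, API key, and model name are placeholders for a locally running `lmdeploy serve api_server`, not values taken from this patch):

```python
from openai import OpenAI

# Placeholder endpoint and model: point these at your own api_server instance.
client = OpenAI(base_url='http://0.0.0.0:23333/v1', api_key='none')

resp = client.chat.completions.create(
    model='internlm/internlm2_5-7b-chat',  # placeholder model name
    messages=[{'role': 'user', 'content': 'Write a haiku about GPUs.'}],
    n=3,              # three sampled choices from one request
    temperature=0.8,
    seed=100,         # with this patch, choices are generated with seed, seed+1, seed+2
)
for choice in resp.choices:
    print(choice.index, choice.message.content)
```

Per the server changes above, `usage.prompt_tokens` is counted once for the shared input while `usage.completion_tokens` sums the tokens of all `n` choices.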
@@ -783,17 +813,13 @@ async def completions_v1(request: CompletionRequest, raw_request: Request = None adapter_name = model_name # got a adapter name request_id = str(request.session_id) created_time = int(time.time()) - sessions = [] if isinstance(request.prompt, str): request.prompt = [request.prompt] - sessions.append(VariableInterface.get_session(request.session_id)) - elif isinstance(request.prompt, list): - for i in range(len(request.prompt)): - sessions.append(VariableInterface.get_session(i + 1)) if isinstance(request.stop, str): request.stop = [request.stop] random_seed = request.seed if request.seed else None max_new_tokens = (request.max_completion_tokens if request.max_completion_tokens else request.max_tokens) + _n = request.n or 1 gen_config = GenerationConfig( max_new_tokens=max_new_tokens, @@ -813,18 +839,34 @@ async def completions_v1(request: CompletionRequest, raw_request: Request = None with_cache=with_cache, preserve_cache=preserve_cache, ) + # Build sessions and generators: for each prompt, create n outputs + # Total generators = len(prompts) * n, indexed as prompt_idx * n + n_idx + sessions = [] generators = [] - for prompt, session in zip(request.prompt, sessions): - result_generator = VariableInterface.async_engine.generate( - prompt, - session, - gen_config=gen_config, - stream_response=True, # always use stream to enable batching - sequence_start=True, - sequence_end=True, - do_preprocess=False, - adapter_name=adapter_name) - generators.append(result_generator) + _first_session = True + for _p_idx, _prompt in enumerate(request.prompt): + for _n_idx in range(_n): + if _first_session: + _sess = VariableInterface.get_session(request.session_id) + _first_session = False + else: + _sess = VariableInterface.get_session(-1) + sessions.append(_sess) + if _n_idx > 0 and random_seed is not None: + _cfg = copy.copy(gen_config) + _cfg.random_seed = random_seed + _n_idx + else: + _cfg = gen_config + generators.append( + VariableInterface.async_engine.generate( + _prompt, + _sess, + gen_config=_cfg, + stream_response=True, # always use stream to enable batching + sequence_start=True, + sequence_end=True, + do_preprocess=False, + adapter_name=adapter_name)) def create_stream_response_json(index: int, text: str, @@ -848,8 +890,7 @@ def create_stream_response_json(index: int, return response_json async def completion_stream_generator() -> AsyncGenerator[str, None]: - # First chunk with role - for generator in generators: + for _gen_idx, generator in enumerate(generators): offset = 0 all_token_ids = [] state = DetokenizeState() @@ -873,7 +914,7 @@ async def completion_stream_generator() -> AsyncGenerator[str, None]: gen_tokens = None if request.return_token_ids: gen_tokens = res.token_ids or [] - response_json = create_stream_response_json(index=0, + response_json = create_stream_response_json(index=_gen_idx, text=res.response, gen_tokens=gen_tokens, finish_reason=res.finish_reason, @@ -1450,13 +1491,13 @@ def serve(model_path: str, server_name (str): host ip for serving server_port (int): server port tp (int): tensor parallel - allow_origins (list[str]): a list of allowed origins for CORS + allow_origins (List[str]): a list of allowed origins for CORS allow_credentials (bool): whether to allow credentials for CORS - allow_methods (list[str]): a list of allowed HTTP methods for CORS - allow_headers (list[str]): a list of allowed HTTP headers for CORS + allow_methods (List[str]): a list of allowed HTTP methods for CORS + allow_headers (List[str]): a list of allowed HTTP headers for 
CORS log_level(str): set log level whose value among [CRITICAL, ERROR, WARNING, INFO, DEBUG] - api_keys (list[str] | str | None): Optional list of API keys. Accepts + api_keys (List[str] | str | None): Optional list of API keys. Accepts string type as a single api_key. Default to None, which means no api key applied. ssl (bool): Enable SSL. Requires OS Environment variables diff --git a/requirements/test.txt b/requirements/test.txt index 6061aaafde..1542c1db05 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -6,6 +6,7 @@ nvidia-ml-py opencv-python-headless pytest pytest-assume +pytest-asyncio pytest-cov pytest-order pytest-rerunfailures diff --git a/tests/test_lmdeploy/test_n_parameter.py b/tests/test_lmdeploy/test_n_parameter.py new file mode 100644 index 0000000000..60119cc38b --- /dev/null +++ b/tests/test_lmdeploy/test_n_parameter.py @@ -0,0 +1,335 @@ +# Copyright (c) OpenMMLab. All rights reserved. +"""Unit tests for the `n` parameter in /v1/chat/completions and +/v1/completions. + +These tests mock the async engine so they run without a GPU or real model. +""" +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from lmdeploy.serve.openai.protocol import ChatCompletionRequest, CompletionRequest +from lmdeploy.serve.openai.serving_chat_completion import check_request as chat_check_request +from lmdeploy.serve.openai.serving_completion import check_request as completion_check_request + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_gen_out(response='hello', finish_reason=None, input_token_len=5, generate_token_len=3): + out = MagicMock() + out.response = response + out.finish_reason = finish_reason + out.token_ids = [1, 2, 3] + out.logprobs = None + out.input_token_len = input_token_len + out.generate_token_len = generate_token_len + out.cache_block_ids = None + return out + + +async def _async_gen(*items): + """Async generator yielding items.""" + for item in items: + yield item + + +def _make_session(session_id=42): + sess = MagicMock() + sess.session_id = session_id + sess.async_abort = AsyncMock() + return sess + + +# --------------------------------------------------------------------------- +# Protocol-level validation +# --------------------------------------------------------------------------- + + +class TestNParameterValidation: + + def _make_server_context(self, session_occupied=False): + ctx = MagicMock() + ctx.get_engine_config.return_value = MagicMock(spec=[]) # no logprobs_mode attr + mgr = MagicMock() + mgr.has.return_value = session_occupied + ctx.get_session_manager.return_value = mgr + return ctx + + def test_chat_n_default_is_1(self): + req = ChatCompletionRequest(model='m', messages='hi') + assert req.n == 1 + + def test_completion_n_default_is_1(self): + req = CompletionRequest(model='m', prompt='hi') + assert req.n == 1 + + def test_chat_n_valid_values(self): + ctx = self._make_server_context() + for n in [1, 2, 5]: + req = ChatCompletionRequest(model='m', messages='hi', n=n) + assert chat_check_request(req, ctx) == '' + + def test_completion_n_valid_values(self): + ctx = self._make_server_context() + for n in [1, 2, 5]: + req = CompletionRequest(model='m', prompt='hi', n=n) + assert completion_check_request(req, ctx) == '' + + def test_chat_n_zero_rejected(self): + ctx = self._make_server_context() + req = ChatCompletionRequest(model='m', messages='hi', n=0) + assert 
chat_check_request(req, ctx) != '' + + def test_completion_n_zero_rejected(self): + ctx = self._make_server_context() + req = CompletionRequest(model='m', prompt='hi', n=0) + assert completion_check_request(req, ctx) != '' + + def test_chat_n_negative_rejected(self): + ctx = self._make_server_context() + req = ChatCompletionRequest(model='m', messages='hi', n=-1) + assert chat_check_request(req, ctx) != '' + + def test_completion_n_negative_rejected(self): + ctx = self._make_server_context() + req = CompletionRequest(model='m', prompt='hi', n=-1) + assert completion_check_request(req, ctx) != '' + + +# --------------------------------------------------------------------------- +# API handler tests (mocking VariableInterface and raw_request) +# --------------------------------------------------------------------------- + + +def _make_raw_request(disconnected=False): + raw = MagicMock() + raw.json = AsyncMock(return_value={}) + raw.is_disconnected = AsyncMock(return_value=disconnected) + return raw + + +def _setup_variable_interface(mock_vi, n_sessions=1, gen_outputs=None): + """Configure the mocked VariableInterface. + + gen_outputs: list of lists – one list of GenOut per generator call. + """ + engine = MagicMock() + engine.model_name = 'test-model' + engine.arch = 'LlamaForCausalLM' + engine.tokenizer = MagicMock() + + sessions = [_make_session(i + 10) for i in range(max(n_sessions, 1))] + _session_iter = iter(sessions) + + def _get_session(sid): + try: + return next(_session_iter) + except StopIteration: + return _make_session(99) + + mock_vi.get_session.side_effect = _get_session + mock_vi.async_engine = engine + mock_vi.tool_parser = None + mock_vi.reasoning_parser = None + mock_vi.allow_terminate_by_client = False + mock_vi.enable_abort_handling = False + + # Each call to engine.generate returns a different async generator + gen_outputs = gen_outputs or [[_make_gen_out('hello', 'stop')]] + + _gen_iter = iter(gen_outputs) + + def _generate(*args, **kwargs): + items = next(_gen_iter, [_make_gen_out('hello', 'stop')]) + return _async_gen(*items) + + engine.generate.side_effect = _generate + return sessions + + +class TestChatCompletionsN: + + @pytest.mark.asyncio + async def test_n1_returns_one_choice(self): + from lmdeploy.serve.openai import api_server + + request = ChatCompletionRequest(model='test-model', messages=[{'role': 'user', 'content': 'hi'}], n=1) + raw_request = _make_raw_request() + + with patch.object(api_server, 'VariableInterface') as mock_vi, \ + patch.object(api_server, 'check_request', return_value=None): + _setup_variable_interface(mock_vi, n_sessions=1, gen_outputs=[[_make_gen_out('ans1', 'stop')]]) + response = await api_server.chat_completions_v1(request, raw_request) + + assert isinstance(response, dict) + assert len(response['choices']) == 1 + assert response['choices'][0]['index'] == 0 + assert response['choices'][0]['message']['content'] == 'ans1' + + @pytest.mark.asyncio + async def test_n3_returns_three_choices(self): + from lmdeploy.serve.openai import api_server + + request = ChatCompletionRequest(model='test-model', messages=[{'role': 'user', 'content': 'hi'}], n=3) + raw_request = _make_raw_request() + + outputs = [ + [_make_gen_out('ans0', 'stop')], + [_make_gen_out('ans1', 'stop')], + [_make_gen_out('ans2', 'stop')], + ] + + with patch.object(api_server, 'VariableInterface') as mock_vi, \ + patch.object(api_server, 'check_request', return_value=None): + _setup_variable_interface(mock_vi, n_sessions=3, gen_outputs=outputs) + response = await 
api_server.chat_completions_v1(request, raw_request) + + assert isinstance(response, dict) + choices = response['choices'] + assert len(choices) == 3 + assert [c['index'] for c in choices] == [0, 1, 2] + assert choices[0]['message']['content'] == 'ans0' + assert choices[1]['message']['content'] == 'ans1' + assert choices[2]['message']['content'] == 'ans2' + + @pytest.mark.asyncio + async def test_n3_usage_aggregates_completion_tokens(self): + from lmdeploy.serve.openai import api_server + + request = ChatCompletionRequest(model='test-model', messages=[{'role': 'user', 'content': 'hi'}], n=3) + raw_request = _make_raw_request() + + # Each generator produces 10 completion tokens + outputs = [[_make_gen_out('a', 'stop', input_token_len=5, generate_token_len=10)]] * 3 + + with patch.object(api_server, 'VariableInterface') as mock_vi, \ + patch.object(api_server, 'check_request', return_value=None): + _setup_variable_interface(mock_vi, n_sessions=3, gen_outputs=outputs) + response = await api_server.chat_completions_v1(request, raw_request) + + usage = response['usage'] + assert usage['prompt_tokens'] == 5 # counted once (shared input) + assert usage['completion_tokens'] == 30 # 3 * 10 + assert usage['total_tokens'] == 35 + + @pytest.mark.asyncio + async def test_n1_uses_request_session_id(self): + from lmdeploy.serve.openai import api_server + + request = ChatCompletionRequest(model='test-model', + messages=[{ + 'role': 'user', + 'content': 'hi' + }], + n=1, + session_id=77) + raw_request = _make_raw_request() + + with patch.object(api_server, 'VariableInterface') as mock_vi, \ + patch.object(api_server, 'check_request', return_value=None): + _setup_variable_interface(mock_vi, n_sessions=1) + await api_server.chat_completions_v1(request, raw_request) + + mock_vi.get_session.assert_called_once_with(77) + + @pytest.mark.asyncio + async def test_n3_uses_auto_sessions(self): + from lmdeploy.serve.openai import api_server + + request = ChatCompletionRequest(model='test-model', messages=[{'role': 'user', 'content': 'hi'}], n=3) + raw_request = _make_raw_request() + + with patch.object(api_server, 'VariableInterface') as mock_vi, \ + patch.object(api_server, 'check_request', return_value=None): + _setup_variable_interface(mock_vi, n_sessions=3, gen_outputs=[[_make_gen_out('x', 'stop')]] * 3) + await api_server.chat_completions_v1(request, raw_request) + + # All 3 sessions should be auto-assigned (-1) + calls = mock_vi.get_session.call_args_list + assert len(calls) == 3 + assert all(c.args[0] == -1 for c in calls) + + @pytest.mark.asyncio + async def test_n3_seeds_are_offset(self): + """When seed is set and n>1, generators should use seed, seed+1, + seed+2.""" + from lmdeploy.serve.openai import api_server + + request = ChatCompletionRequest(model='test-model', messages=[{'role': 'user', 'content': 'hi'}], n=3, seed=100) + raw_request = _make_raw_request() + + captured_configs = [] + + def _generate(*args, **kwargs): + captured_configs.append(kwargs.get('gen_config')) + return _async_gen(_make_gen_out('x', 'stop')) + + with patch.object(api_server, 'VariableInterface') as mock_vi, \ + patch.object(api_server, 'check_request', return_value=None): + _setup_variable_interface(mock_vi, n_sessions=3) + mock_vi.async_engine.generate.side_effect = _generate + await api_server.chat_completions_v1(request, raw_request) + + seeds = [cfg.random_seed for cfg in captured_configs] + assert seeds == [100, 101, 102] + + +class TestCompletionsN: + + @pytest.mark.asyncio + async def 
test_n1_single_prompt_one_choice(self): + from lmdeploy.serve.openai import api_server + + request = CompletionRequest(model='test-model', prompt='hi', n=1) + raw_request = _make_raw_request() + + with patch.object(api_server, 'VariableInterface') as mock_vi, \ + patch.object(api_server, 'check_request', return_value=None): + _setup_variable_interface(mock_vi, n_sessions=1, gen_outputs=[[_make_gen_out('out0', 'stop')]]) + response = await api_server.completions_v1(request, raw_request) + + assert len(response['choices']) == 1 + assert response['choices'][0]['index'] == 0 + + @pytest.mark.asyncio + async def test_n3_single_prompt_three_choices(self): + from lmdeploy.serve.openai import api_server + + request = CompletionRequest(model='test-model', prompt='hi', n=3) + raw_request = _make_raw_request() + + outputs = [ + [_make_gen_out('out0', 'stop')], + [_make_gen_out('out1', 'stop')], + [_make_gen_out('out2', 'stop')], + ] + + with patch.object(api_server, 'VariableInterface') as mock_vi, \ + patch.object(api_server, 'check_request', return_value=None): + _setup_variable_interface(mock_vi, n_sessions=3, gen_outputs=outputs) + response = await api_server.completions_v1(request, raw_request) + + choices = response['choices'] + assert len(choices) == 3 + assert [c['index'] for c in choices] == [0, 1, 2] + + @pytest.mark.asyncio + async def test_n2_two_prompts_four_choices(self): + """2 prompts × n=2 = 4 choices, indexed 0..3.""" + from lmdeploy.serve.openai import api_server + + request = CompletionRequest(model='test-model', prompt=['p0', 'p1'], n=2) + raw_request = _make_raw_request() + + outputs = [[_make_gen_out(f'out{i}', 'stop')] for i in range(4)] + + with patch.object(api_server, 'VariableInterface') as mock_vi, \ + patch.object(api_server, 'check_request', return_value=None): + _setup_variable_interface(mock_vi, n_sessions=4, gen_outputs=outputs) + response = await api_server.completions_v1(request, raw_request) + + choices = response['choices'] + assert len(choices) == 4 + assert [c['index'] for c in choices] == [0, 1, 2, 3]
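The seed expectation in `test_n3_seeds_are_offset` mirrors the per-choice fan-out performed in `chat_completions_v1` (and `completions_v1`). A stripped-down sketch of that logic, with `GenConfigSketch` as a hypothetical stand-in for lmdeploy's `GenerationConfig`:

```python
import copy
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class GenConfigSketch:
    """Hypothetical stand-in for GenerationConfig; only the seed field matters here."""
    random_seed: Optional[int] = None


def fan_out_configs(base: GenConfigSketch, n: int) -> List[GenConfigSketch]:
    """Choice 0 reuses the request config; choices 1..n-1 get copies with seed + i."""
    configs = []
    for i in range(n):
        if i > 0 and base.random_seed is not None:
            cfg = copy.copy(base)
            cfg.random_seed = base.random_seed + i
        else:
            cfg = base
        configs.append(cfg)
    return configs


assert [c.random_seed for c in fan_out_configs(GenConfigSketch(100), 3)] == [100, 101, 102]
assert all(c.random_seed is None for c in fan_out_configs(GenConfigSketch(), 2))
```

When no seed is supplied, every choice shares the original config and the choices differ only through sampling.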