Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions lmdeploy/pytorch/engine/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -443,13 +443,13 @@ def update_params(self, request: Any):
"""Update params."""
self.executor.update_params(request)

def sleep(self, level: int = 1):
async def sleep(self, level: int = 1):
"""Sleep."""
self.executor.sleep(level)
await self.executor.sleep(level)

def wakeup(self, tags: list[str] | None = None):
async def wakeup(self, tags: list[str] | None = None):
"""Wakeup."""
self.executor.wakeup(tags)
await self.executor.wakeup(tags)

async def async_loop(self):
engine_loop = None
Expand Down
10 changes: 10 additions & 0 deletions lmdeploy/pytorch/engine/engine_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,16 @@ def __get_logprobs(batched_outputs: 'BatchedOutputs'):
seq.append_logits(logits)
return dict()

engine_error_msg = getattr(batched_outputs, 'engine_error_msg', None)
if engine_error_msg:
for msg in running:
if msg.status != MessageStatus.RUNNING:
continue
response_reqs(self.req_manager, msg.resp, ResponseType.INTERNAL_ENGINE_ERROR,
data=dict(token_ids=[]), err_msg=engine_error_msg)
msg.state.finish()
return dict()

new_token_timestamp = batched_outputs.new_token_timestamp
logprobs = batched_outputs.logprobs

Expand Down
2 changes: 1 addition & 1 deletion lmdeploy/pytorch/engine/executor/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ async def sleep(self, level: int = 1):
"""Sleep."""
raise NotImplementedError('Not Implemented.')

def wakeup(self, tags: list[str] | None = None):
async def wakeup(self, tags: list[str] | None = None):
"""Wakeup."""
raise NotImplementedError('Not Implemented.')

Expand Down
2 changes: 1 addition & 1 deletion lmdeploy/pytorch/engine/executor/base_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ async def sleep(self, level: int = 1):
"""Sleep."""
await self.model_agent.sleep(level)

def wakeup(self, tags: list[str] | None = None):
async def wakeup(self, tags: list[str] | None = None):
"""Wakeup."""
self.model_agent.wakeup(tags)

Expand Down
8 changes: 8 additions & 0 deletions lmdeploy/pytorch/engine/executor/mp_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,6 +373,14 @@ def warmup(self):
"""Build cache engine."""
self.collective_rpc('warmup')

async def sleep(self, level: int = 1):
    """Put the model to sleep by broadcasting an async RPC to every worker.

    Args:
        level (int): sleep level forwarded to each worker's ``sleep``.
    """
    rpc_args = (level, )
    await self.collective_rpc_async('sleep', args=rpc_args)

async def wakeup(self, tags: list[str] | None = None):
"""Wakeup."""
await self.collective_rpc_async('wakeup', args=(tags, ))

async def _prefetch_outputs(self):
while True:
out = (await self.collective_rpc_async('get_outputs', receiver_mask=1, return_mask=1))[0]
Expand Down
10 changes: 8 additions & 2 deletions lmdeploy/pytorch/engine/executor/ray_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,12 +353,18 @@ def warmup(self):
"""Build cache engine."""
self.collective_rpc('warmup')

def sleep(self, level: int = 1):
async def sleep(self, level: int = 1):
"""Sleep."""
await asyncio.to_thread(self._sleep_collective_rpc, level)

def _sleep_collective_rpc(self, level: int):
self.collective_rpc('sleep', (level, ))

def wakeup(self, tags: list[str] | None = None):
async def wakeup(self, tags: list[str] | None = None):
"""Wakeup."""
await asyncio.to_thread(self._wakeup_collective_rpc, tags)

def _wakeup_collective_rpc(self, tags: list[str] | None):
if tags is None or 'kv_cache' in tags:
self.update_configs()
self.collective_rpc('wakeup', (tags, ))
Expand Down
8 changes: 8 additions & 0 deletions lmdeploy/pytorch/engine/executor/uni_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,14 @@ async def get_output_async(self, dp_rank: int = 0):
assert dp_rank == 0
return await self.model_agent.get_output_async()

async def sleep(self, level: int = 1):
    """Put the single local model agent to sleep.

    Args:
        level (int): sleep level forwarded to ``model_agent.sleep``.
    """
    agent = self.model_agent
    await agent.sleep(level)

async def wakeup(self, tags: list[str] | None = None):
"""Wakeup on the event-loop thread (CUDA-safe; may block the loop)."""
self.model_agent.wakeup(tags)

def get_input_processor(self):
"""Get input processor."""
return self.model_agent.get_input_processor()
Expand Down
2 changes: 1 addition & 1 deletion lmdeploy/pytorch/engine/model_agent/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from lmdeploy.pytorch.devices import DeviceContext, get_device_manager
from lmdeploy.pytorch.distributed import DistContext, get_dist_manager

from .agent import BaseModelAgent, BatchedOutputs # noqa: F401
from .agent import BaseModelAgent, BatchedOutputs, CacheNotReadyError # noqa: F401


def build_model_agent(
Expand Down
54 changes: 52 additions & 2 deletions lmdeploy/pytorch/engine/model_agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,12 @@
logger = get_logger('lmdeploy')


class CacheNotReadyError(RuntimeError):
    """Raised when a forward step runs while the KV/state cache engines are
    missing, e.g. after a sleep or a partial wakeup."""


@dataclass
class SleepWakeupState:
to_sleep: asyncio.Event = field(default_factory=asyncio.Event)
Expand Down Expand Up @@ -82,6 +88,7 @@ class BatchedOutputs:
new_token_timestamp: int = 0
extra_outputs: ExtraOutputs | None = None
all_routed_experts: torch.Tensor | None = None
engine_error_msg: str | None = None

def to_cpu(self):
"""To cpu."""
Expand Down Expand Up @@ -128,11 +135,15 @@ def msg_with_rank(rank: int, msg: str):
return f'rank[{rank}] - {msg}'


def cache_swapping(cache_engine: CacheEngine, swap_in_map: dict, swap_out_map: dict):
def cache_swapping(cache_engine: CacheEngine | None, swap_in_map: dict, swap_out_map: dict):
"""Perform cache swapping."""
issued_cache_op = False
swap_in_map = swap_in_map or dict()
swap_out_map = swap_out_map or dict()
if cache_engine is None and (len(swap_in_map) > 0 or len(swap_out_map) > 0):
raise CacheNotReadyError(
'KV cache is not available; cannot swap blocks. '
"Restore cache via wakeup with the 'kv_cache' tag before inference.")
if len(swap_in_map) > 0:
cache_engine.swap_in(swap_in_map)
issued_cache_op = True
Expand Down Expand Up @@ -568,6 +579,27 @@ def _push_output(self, output: BatchedOutputs):
event.record()
self._out_que.put_nowait((output, event))

def _batched_outputs_for_cache_error(self, forward_inputs: dict, err_msg: str) -> BatchedOutputs:
    """Build a minimal ``BatchedOutputs`` that only carries an error message.

    Used so ``get_output_async`` still receives one output per scheduled
    batch when the forward step was skipped due to a cache error.

    Args:
        forward_inputs (dict): the inputs of the skipped step; ``inputs``
            or ``delta`` (when present) determine the batch size.
        err_msg (str): error message propagated via ``engine_error_msg``.
    """
    model_inputs = forward_inputs.get('inputs')
    step_delta = forward_inputs.get('delta')
    # Derive the batch size from whichever input description is present;
    # fall back to a single-sequence batch when neither is available.
    if model_inputs is not None:
        num_seqs = int(model_inputs.seq_length.size(0))
    elif step_delta is not None:
        num_seqs = int(step_delta.block_offsets.size(0))
    else:
        num_seqs = 1
    dev = self.device
    # All sequences are marked stopped with dummy token ids; the error
    # message tells the engine loop to fail them.
    return BatchedOutputs(
        next_token_ids=torch.zeros((num_seqs, 1), dtype=torch.long, device=dev),
        stopped=torch.ones(num_seqs, dtype=torch.bool, device=dev),
        stop_pos=torch.zeros(num_seqs, dtype=torch.long, device=dev),
        engine_error_msg=err_msg,
    )

@contextmanager
def _broadcast_next_token(self, next_token_ids: torch.Tensor, extra_inputs: ExtraInputs, enable: bool = True):
if not enable:
Expand Down Expand Up @@ -644,6 +676,10 @@ def _get_inputs_from_delta(
sampling_inputs: SamplingInputs,
):
"""Get inputs from delta."""
if self.step_inputs.model_inputs is None:
raise CacheNotReadyError(
'Decode step has no cached ModelInputs (e.g. after a KV-cache error or reset). '
"Call wakeup with the 'kv_cache' tag before continuing inference.")
self.step_inputs.update_delta(delta, self)
inputs = self.step_inputs.model_inputs
extra_inputs = self.step_inputs.extra_inputs
Expand All @@ -661,6 +697,10 @@ def _prepare_inputs_prefill(
if delta is not None:
# update decoding inputs with delta
# for second round chat
if self.step_inputs.model_inputs is None:
raise CacheNotReadyError(
'Prefill with delta requires prior ModelInputs (e.g. after a KV-cache error). '
"Call wakeup with the 'kv_cache' tag before continuing inference.")
self.step_inputs.update_delta(delta, self)

if inputs.is_first_chunk:
Expand Down Expand Up @@ -836,6 +876,11 @@ def __update_inputs(
await asyncio.sleep(0.01)
return

if self.cache_engine is None or self.state_cache_engine is None:
raise CacheNotReadyError(
'KV or state cache engine is not built (e.g. after sleep or partial wakeup). '
"Call wakeup with the 'kv_cache' tag before running inference.")

# swap caches
cache_swapping(self.cache_engine, swap_in_map=swap_in_map, swap_out_map=swap_out_map)

Expand Down Expand Up @@ -936,7 +981,12 @@ async def _async_loop_background(self, forward_event: asyncio.Event = None):
while True:
forward_inputs = await input_maker.get()

await self._async_step(**forward_inputs, )
try:
await self._async_step(**forward_inputs, )
except CacheNotReadyError as err:
logger.warning('Forward skipped: %s', err)
self.step_inputs = StepInputs()
self._push_output(self._batched_outputs_for_cache_error(forward_inputs, str(err)))
if forward_event is not None:
forward_event.set()

Expand Down
8 changes: 4 additions & 4 deletions lmdeploy/pytorch/engine/mp_engine/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,13 +53,13 @@ def end_session(self, session_id: int):
"""End session."""
return self._collective_rpc('end_session', session_id)

def sleep(self, level: int):
async def sleep(self, level: int):
"""sleep."""
return self._collective_rpc('sleep', level)
return await self._collective_rpc_async('sleep', level)

def wakeup(self, tags: list[str] | None = None):
async def wakeup(self, tags: list[str] | None = None):
"""Wakeup."""
return self._collective_rpc('wakeup', tags)
return await self._collective_rpc_async('wakeup', tags)

def update_params(self, request: Any):
"""Update params."""
Expand Down
8 changes: 4 additions & 4 deletions lmdeploy/pytorch/engine/mp_engine/base_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,13 +100,13 @@ def p2p_drop_connect(self, drop_conn_request: DistServeDropConnectionRequest):
"""
return self.engine.p2p_drop_connect(drop_conn_request)

def sleep(self, level: int = 1):
async def sleep(self, level: int = 1):
"""sleep."""
return self.engine.sleep(level)
return await self.engine.sleep(level)

def wakeup(self, tags: list[str] | None = None):
async def wakeup(self, tags: list[str] | None = None):
"""Wakeup."""
return self.engine.wakeup(tags)
return await self.engine.wakeup(tags)

def update_params(self, request: Any):
"""Update params."""
Expand Down
50 changes: 40 additions & 10 deletions lmdeploy/serve/core/async_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ class GenOut:
history_token_len: int
input_token_len: int
generate_token_len: int
finish_reason: Literal['stop', 'length', 'error'] | None = None
finish_reason: Literal['stop', 'length', 'error', 'abort'] | None = None
token_ids: list[int] | None = None
logprobs: list[dict[int, float]] | None = None
logits: Any = None
Expand Down Expand Up @@ -201,6 +201,23 @@ def _build_stat_loggers(self):
# set stats loggers of metrics processor
metrics_processor.stat_loggers = self.stat_loggers

def _if_session_stale(self, session: Session,
input_token_len: int) -> GenOut | None:
"""If ``session.epoch`` was stamped by api_server and
``stop_all_session`` ran since then (the engine epoch changed), drop
the session."""
epoch = session.epoch
if epoch is None or epoch == self.epoch:
return None
logger.info(
f'[generate] session {session.session_id} dropped (session.epoch={epoch}, epoch={self.epoch})')
return GenOut(response='',
history_token_len=session.step,
input_token_len=input_token_len,
generate_token_len=0,
finish_reason='abort',
token_ids=[])

def get_schedule_metrics(self):
return self.engine.get_schedule_metrics()

Expand All @@ -212,23 +229,25 @@ async def do_log_stats(self):

async def stop_all_session(self):
"""Stop all running sessions."""
logger.info('stop all sessions')
logger.info(f'stop all sessions, epoch {self.epoch} -> {self.epoch + 1}')
self.epoch += 1
await self.session_mgr.async_abort_all()

def sleep(self, level: int = 1):
async def sleep(self, level: int = 1):
"""Sleep the model.

Args:
level (int): The sleep level. Level 1 sleep will offload the model
weights and discard the kv cache. Level 2 sleep will
discard both the model weights and the kv cache.
"""
self.engine.sleep(level)
logger.info(f'[async_engine]sleep, level={level}')
await self.engine.sleep(level)
self.sleeping_tags = {'weights', 'kv_cache'}
self.is_sleeping = True
logger.info('[async_engine] sleep, done')

def wakeup(self, tags: list[str] | None = None):
async def wakeup(self, tags: list[str] | None = None):
"""Wake up the model.

Args:
Expand All @@ -242,7 +261,7 @@ def wakeup(self, tags: list[str] | None = None):
if any(tag not in self.sleeping_tags for tag in tags):
logger.warning(f'some tag in {tags} not in sleeping tags {self.sleeping_tags}')
return
self.engine.wakeup(tags)
await self.engine.wakeup(tags)
# for TM backend, sleep/wakeup will reset gateway, therefore we need to rebuild instances
if self.backend == 'turbomind' and 'kv_cache' in tags:
self.session_mgr.build_request_handle_pool(self.engine, self.backend_config.max_batch_size)
Expand Down Expand Up @@ -339,7 +358,8 @@ async def generate(
do_preprocess (bool): whether pre-process the messages. Default to
True, which means chat_template will be applied.
"""
epoch = self.epoch
metrics_processor.increase_total_requests()

if (messages is not None) ^ (input_ids is None):
raise ValueError('You must specify exactly one of messages or input_ids')
if isinstance(session_id, Session):
Expand Down Expand Up @@ -386,6 +406,7 @@ async def generate(

if gen_config.max_new_tokens == 0:
logger.info(f'run out of tokens. session={session_id}.')
metrics_processor.increase_failed_requests('error')
yield GenOut(response='',
history_token_len=session.step,
input_token_len=len(input_ids),
Expand All @@ -400,6 +421,7 @@ async def generate(
or gen_config.output_logits == 'all'):
errmsg = ('lmdeploy does not support outputting all token\'s logits or last_hidden_state '
'when prefix caching is ON')
metrics_processor.increase_failed_requests('error')
yield GenOut(response=errmsg,
history_token_len=session.step,
input_token_len=len(input_ids),
Expand All @@ -421,10 +443,18 @@ def is_error(status):
if not gen_config.ignore_eos:
stop_ids = gen_config.stop_token_ids or []

metrics_processor.increase_total_requests()

stale = self._if_session_stale(session, len(prompt_input['input_ids']))
if stale is not None:
metrics_processor.increase_failed_requests('abort')
yield stale
if sequence_end:
self.session_mgr.remove(session)
return
async with session.request_handle() as handle:
if epoch != self.epoch:
logger.info(f'[generate] session {session_id} got aborted before starting inference')
if session.epoch is not None and session.epoch != self.epoch:
logger.info(f'[generate] session {session_id} got aborted before starting inference, '
f'session.epoch={session.epoch}, epoch={self.epoch}')
metrics_processor.increase_failed_requests('abort')
yield GenOut(response='',
history_token_len=0,
Expand Down
Loading
Loading