136 changes: 14 additions & 122 deletions lmdeploy/pytorch/engine/model_agent/agent.py
@@ -20,7 +20,7 @@
from lmdeploy.pytorch.distributed import DistContext, get_dist_manager
from lmdeploy.pytorch.engine.cache_engine import CacheEngine, StateCacheEngine
from lmdeploy.pytorch.engine.guided_process import GuidedDecodingManager
from lmdeploy.pytorch.engine.logits_process import FusedLogitsProcessor, SamplingInputs, SamplingInputsDelta
from lmdeploy.pytorch.engine.logits_process import FusedLogitsProcessor, SamplingInputs
from lmdeploy.pytorch.model_inputs import ModelInputs, ModelInputsDelta, step_ctx_manager
from lmdeploy.pytorch.models.patch import BuildModelContext, add_adapters, build_patched_model, update_custom_module_map
from lmdeploy.pytorch.spec_decode import build_spec_agent
@@ -229,92 +229,6 @@ async def async_wait(self, timeout: float = 0.001):
SwapMap = dict[int, int]


@dataclass
class StepInputs:
"""Step inputs."""
model_inputs: ModelInputs = None
extra_inputs: ExtraInputs = None
stopping_criteria: StoppingCriteria = None
sampling_delta: SamplingInputsDelta = None

@record_function('StepInputs.merge')
def merge(
self,
inputs: ModelInputs,
extra_inputs: ExtraInputs,
stopping_criteria: StoppingCriteria,
sampling_delta: SamplingInputsDelta,
next_token_ids: torch.Tensor,
model_metas,
extra_outputs: ExtraOutputs,
model_agent: 'BaseModelAgent',
):
"""Merge prefill inputs."""
inputs, extra_inputs = model_agent.agent_strategy.update_prefill_for_next_step(
inputs,
extra_inputs,
next_token_ids,
model_metas,
extra_outputs,
)
stopping_criteria = stopping_criteria.clone()
sampling_delta = model_agent.sampling_strategy.step_sampling_delta(sampling_delta,
next_token_ids,
extra_inputs=extra_inputs)
if self.model_inputs is None:
self.model_inputs = inputs
self.extra_inputs = extra_inputs
self.stopping_criteria = stopping_criteria
self.sampling_delta = sampling_delta
else:
self.model_inputs = model_agent.inputs_strategy.merge(self.model_inputs, inputs)
self.extra_inputs = self.extra_inputs.merge(extra_inputs)
self.stopping_criteria = self.stopping_criteria.merge(stopping_criteria)
self.sampling_delta = model_agent.sampling_strategy.merge_sampling_delta(
self.sampling_delta, sampling_delta)

def update_delta(
self,
delta: ModelInputsDelta,
model_agent: 'BaseModelAgent',
):
"""Get inputs from delta."""
self.model_inputs = model_agent.inputs_strategy.update_inputs(self.model_inputs, delta)
self.extra_inputs = model_agent.agent_strategy.update_extra_inputs(self.extra_inputs, delta)
self.stopping_criteria = self.stopping_criteria.update(delta)
self.sampling_delta = model_agent.sampling_strategy.update_sampling_delta(self.sampling_delta, delta)

@record_function('StepInputs.step')
def step(
self,
model_inputs: ModelInputs,
extra_inputs: ExtraInputs,
stopping_criteria: StoppingCriteria,
sampling_delta: SamplingInputsDelta,
next_token_ids: torch.Tensor,
model_metas,
extra_outputs: ExtraOutputs,
model_agent: 'BaseModelAgent',
):
"""Update inputs."""
# dp might change is_decoding of decoding inputs
model_inputs.is_decoding = True
(
self.model_inputs,
self.extra_inputs,
) = model_agent.agent_strategy.update_decoding_for_next_step(
model_inputs,
next_token_ids=next_token_ids,
model_metas=model_metas,
extra_inputs=extra_inputs,
extra_outputs=extra_outputs,
)
self.stopping_criteria = stopping_criteria.clone()
self.sampling_delta = model_agent.sampling_strategy.step_sampling_delta(sampling_delta,
next_token_ids,
extra_inputs=extra_inputs)


class BaseModelAgent:
"""Base model agent.

@@ -421,7 +335,7 @@ def __init__(
self.state: SleepWakeupState = SleepWakeupState()

# decoding inputs
self.step_inputs = StepInputs()
self.step_inputs = self.strategy_factory.build_step_inputs()

# long context
self._prev_chunk_output: dict = None
@@ -644,7 +558,7 @@ def _get_inputs_from_delta(
sampling_inputs: SamplingInputs,
):
"""Get inputs from delta."""
self.step_inputs.update_delta(delta, self)
self.step_inputs.reindex(delta)
inputs = self.step_inputs.model_inputs
extra_inputs = self.step_inputs.extra_inputs
stopping_criteria = self.step_inputs.stopping_criteria
@@ -661,7 +575,7 @@ def _prepare_inputs_prefill(
if delta is not None:
# update decoding inputs with delta
# for second round chat
self.step_inputs.update_delta(delta, self)
self.step_inputs.reindex(delta)

if inputs.is_first_chunk:
self._prev_chunk_output = None
@@ -768,29 +682,6 @@ async def _async_step(
):
"""Asyc forward task."""

@record_function('update_decoding_for_next_step')
def __update_inputs(
inputs,
next_token_ids,
model_metas,
extra_inputs,
extra_outputs,
stopping_criteria,
sampling_delta: SamplingInputsDelta = None,
):
"""Update inputs."""
# dp might change is_decoding of decoding inputs
self.step_inputs.step(
inputs,
extra_inputs,
stopping_criteria,
sampling_delta,
next_token_ids,
model_metas,
extra_outputs,
model_agent=self,
)

dist_ctx = get_dist_manager().current_context()
dist_config = dist_ctx.dist_config
rank = self.rank
@@ -904,26 +795,27 @@ def __update_inputs(

sampling_delta = sampling_inputs.get_delta()
if need_update_inputs:
__update_inputs(inputs,
next_token_ids,
model_metas,
extra_inputs,
extra_outputs,
stopping_criteria,
sampling_delta=sampling_delta)
self.step_inputs.step_decode(
inputs,
extra_inputs,
stopping_criteria,
sampling_delta,
next_token_ids,
model_metas,
extra_outputs,
)
elif inputs.is_chunk and not inputs.is_last_chunk:
# _prev_chunk_output is used to update model metas
self._prev_chunk_output = output
elif self.cache_config.role != EngineRole.Prefill:
self.step_inputs.merge(
self.step_inputs.merge_prefill(
inputs,
extra_inputs,
stopping_criteria,
sampling_delta,
next_token_ids,
model_metas,
extra_outputs,
model_agent=self,
)

async def _async_loop_background(self, forward_event: asyncio.Event = None):
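The strategy-owned `StepInputs` base class (`lmdeploy/pytorch/strategies/base/step_inputs.py`, imported in the next file) is not included in this diff. The sketch below is only a guess at the interface the agent now drives, reconstructed from the call sites above (`reindex`, `merge_prefill`, `step_decode`) and the attributes the agent still reads; field names, signatures, and base-class details are assumptions, not the real implementation.

```python
# Hypothetical interface sketch -- NOT the actual strategies/base/step_inputs.py.
# Everything here is inferred from the agent.py call sites shown in this diff.
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any

import torch

from lmdeploy.pytorch.model_inputs import ModelInputs, ModelInputsDelta


@dataclass
class StepInputs(ABC):
    """Per-strategy container for the cached decoding-step inputs."""
    model_inputs: ModelInputs = None
    extra_inputs: Any = None       # ExtraInputs; read back by the agent
    stopping_criteria: Any = None  # StoppingCriteria; read back by the agent
    sampling_delta: Any = None     # SamplingInputsDelta

    @abstractmethod
    def reindex(self, delta: ModelInputsDelta):
        """Re-map the cached inputs with a scheduler delta (replaces update_delta)."""

    @abstractmethod
    def merge_prefill(self, inputs: ModelInputs, extra_inputs, stopping_criteria,
                      sampling_delta, next_token_ids: torch.Tensor, model_metas,
                      extra_outputs):
        """Fold finished prefill outputs into the pending decode batch (replaces merge)."""

    @abstractmethod
    def step_decode(self, inputs: ModelInputs, extra_inputs, stopping_criteria,
                    sampling_delta, next_token_ids: torch.Tensor, model_metas,
                    extra_outputs):
        """Advance the cached decode inputs by one generation step (replaces step)."""
```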
8 changes: 8 additions & 0 deletions lmdeploy/pytorch/strategies/ar/__init__.py
@@ -11,6 +11,7 @@
from lmdeploy.pytorch.strategies.base.model_agent import ModelAgentStrategy
from lmdeploy.pytorch.strategies.base.model_inputs import ModelInputsStrategy
from lmdeploy.pytorch.strategies.base.sampling import SamplingStrategy
from lmdeploy.pytorch.strategies.base.step_inputs import StepInputs

from ..base import StrategyFactoryBase

@@ -52,3 +53,10 @@ def build_engine_strategy(self, cache_config: 'CacheConfig',
def build_sequence_strategy(self) -> SequenceStrategy:
from .sequence import ARSequenceStrategy
return ARSequenceStrategy()

def build_step_inputs(self) -> 'StepInputs':
"""Build step inputs for the decoding loop."""
from .step_inputs import ARStepInputs
pad_token_id = self.model_config.bos_token_id
pad_token_id = 0 if pad_token_id is None else pad_token_id
return ARStepInputs(_pad_token_id=pad_token_id)
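For context, a rough usage sketch of the new factory hook, assuming the interface sketched earlier; `strategy_factory` and the argument bundles stand in for whatever the engine actually wires up.

```python
# Hypothetical driver-side usage; argument names mirror the agent.py call sites above.
step_inputs = strategy_factory.build_step_inputs()  # ARStepInputs for the AR strategy

# decode path: advance the cached inputs by one generation step
step_inputs.step_decode(inputs, extra_inputs, stopping_criteria, sampling_delta,
                        next_token_ids, model_metas, extra_outputs)

# non-Prefill role: fold finished prefill outputs into the pending decode batch
step_inputs.merge_prefill(inputs, extra_inputs, stopping_criteria, sampling_delta,
                          next_token_ids, model_metas, extra_outputs)

# re-map the cached inputs when the scheduler emits a ModelInputsDelta
step_inputs.reindex(delta)
```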
52 changes: 0 additions & 52 deletions lmdeploy/pytorch/strategies/ar/model_agent.py
@@ -1,7 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Any

import torch
import torch.distributed as dist
@@ -17,37 +16,6 @@
SeqList = list[SchedulerSequence]


def get_model_inputs_next_decoding(inputs: ModelInputs, input_ids: torch.Tensor, max_q_seqlen: int,
model_metas) -> ModelInputs:
"""Next decoding step."""
if input_ids.dim() == 1:
input_ids = input_ids[None, :]
state_offsets = inputs.state_offsets
if state_offsets is not None:
state_offsets = state_offsets.clone()

# mrope
mrope_pos_ids = inputs.mrope_pos_ids
if mrope_pos_ids is not None:
index = inputs.seq_length.cumsum(0) - 1
mrope_pos_ids = mrope_pos_ids[:, index] + 1
return ModelInputs(
input_ids=input_ids,
seq_length=torch.full_like(inputs.seq_length, max_q_seqlen),
history_lengths=inputs.history_lengths + inputs.seq_length,
block_offsets=inputs.block_offsets,
is_decoding=True,
num_ignored_history=inputs.num_ignored_history.clone(),
max_q_seqlen=max_q_seqlen,
max_kv_seqlen=inputs.max_kv_seqlen + max_q_seqlen,
sum_kv_seqlen=inputs.sum_kv_seqlen + inputs.seq_length.numel() * inputs.max_q_seqlen,
local_adapter_ids=inputs.local_adapter_ids,
model_metas=model_metas,
state_offsets=state_offsets,
mrope_pos_ids=mrope_pos_ids,
)


@dataclass
class ARExtraInputs(ExtraInputs):
"""Ar extra inputs."""
@@ -145,26 +113,6 @@ def make_extra_outputs(self, extra_inputs: ARExtraInputs) -> ARExtraOutputs:
"""Create extra outputs."""
return ARExtraOutputs()

def update_prefill_for_next_step(
self,
model_inputs: 'ModelInputs',
extra_inputs: ARExtraInputs,
next_token_ids: torch.Tensor,
model_metas: Any,
extra_outputs: ARExtraOutputs,
) -> tuple['ModelInputs', ARExtraInputs]:
"""Step next decoding."""
inputs = get_model_inputs_next_decoding(model_inputs, next_token_ids, max_q_seqlen=1, model_metas=model_metas)
return inputs, extra_inputs

def update_decoding_for_next_step(self, model_inputs: 'ModelInputs', next_token_ids: torch.Tensor, model_metas: Any,
extra_inputs: ARExtraInputs, **kwargs):
"""Step next inputs."""
model_inputs.model_metas = model_metas
step_seqlens = model_inputs.seq_length
model_inputs = model_inputs.step(next_token_ids, step_seqlens)
return model_inputs, extra_inputs

def post_sampling(self, inputs: 'ModelInputs', logits: torch.Tensor, next_token_ids: torch.LongTensor,
extra_inputs: ARExtraInputs):
"""Post sampling."""