Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/unit_test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ jobs:
lmdeploy check_env
- name: Test lmdeploy python UT
run: |
coverage run --branch --source lmdeploy -m pytest -rsE tests
coverage run --branch --source lmdeploy -m pytest -vv tests
coverage xml
coverage report -m
- name: Clear workfile
Expand Down
5 changes: 5 additions & 0 deletions lmdeploy/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ class GenerationConfig:
logits_processors: list[LogitsProcessor] | None = None
output_logits: Literal['all', 'generation'] = None
output_last_hidden_state: Literal['all', 'generation'] = None
output_ppl: bool = False
include_stop_str_in_output: bool = False

# for disaggregation
Expand Down Expand Up @@ -496,6 +497,8 @@ class Response:
logprobs: list[dict[int, float]] = None
logits: torch.Tensor = None
last_hidden_state: torch.Tensor = None
ppl_loss: float = None
ppl_count: int = None
index: int = 0
routed_experts: Any = None

Expand Down Expand Up @@ -632,6 +635,8 @@ class EngineOutput:
logprobs: list[dict[int, float]] = None
logits: torch.Tensor = None
last_hidden_state: torch.Tensor = None
ppl_loss: float = None
ppl_count: int = None
cache_block_ids: list[int] | None = None
req_metrics: RequestMetrics | None = None
routed_experts: torch.Tensor = None
Expand Down
144 changes: 10 additions & 134 deletions lmdeploy/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
from threading import Thread
from typing import TYPE_CHECKING

import torch
import tqdm
from typing_extensions import deprecated

Expand Down Expand Up @@ -262,8 +261,12 @@ def get_reward_score(self, input_ids: list) -> list[float]:
return scores

def get_ppl(self, input_ids: list[int] | list[list[int]]) -> list[float]:
"""Get perplexity scores given a list of input tokens that have to be
of the same length.
"""Get perplexity scores given a list of input tokens.

Cross-entropy is computed inline inside the engine (on GPU logits
chunks) so the full ``[seq_len, vocab_size]`` logits tensor is never
materialised on CPU, keeping memory usage constant regardless of
sequence length.

Args:
input_ids: the batch of input token ids.
Expand All @@ -276,40 +279,10 @@ def get_ppl(self, input_ids: list[int] | list[list[int]]) -> list[float]:
input_ids = [input_ids]
assert all(len(_) > 1 for _ in input_ids)

# TODO: a better way to determine `max_input_len`, at most allocate
# 2G mem for logits with shape [bs, max_input_len, vocab_size]
vocab_size = self.async_engine.hf_cfg.vocab_size
max_input_len = 2 * 1024**3 // (vocab_size * 4)
sizes = [len(_) for _ in input_ids]
result = []
sorted_index_values = sorted(list(enumerate(sizes)), key=lambda x: x[1], reverse=True)
sizes = [value for index, value in sorted_index_values]
indices = [index for index, value in sorted_index_values]
logger.info(f'sorted sizes: {sizes}')
logger.info(f'sorted indices: {indices}')
for (start, end) in self._batch_iterator(sizes, max_input_len):
logger.info(f'start: {start}, end: {end}')
if start == end:
_input_ids = input_ids[indices[start]]
session = self.session_mgr.get()
res = self._get_long_text_ppl(session, input_ids=_input_ids, max_input_len=max_input_len)
result.append(res)
self.session_mgr.remove(session)
else:
_input_ids = [input_ids[indices[i]] for i in range(start, end)]
sessions = [self.session_mgr.get() for _ in range(start, end)]
res = self._get_ppl(
sessions=sessions,
input_ids=_input_ids,
max_input_len=max_input_len,
)
result.extend(res)
for session in sessions:
self.session_mgr.remove(session)
output = list(range(len(result)))
for index, sorted_index in enumerate(indices):
output[sorted_index] = result[index]
return output
results = self._run(
coro=self.async_engine.async_get_ppl(input_ids=input_ids)
).result()
return [loss / count if count > 0 else 0.0 for loss, count in results]

def __call__(self,
prompts: list[str] | str | list[dict] | list[list[dict]],
Expand Down Expand Up @@ -429,103 +402,6 @@ async def _coro():
coro = _coro()
return asyncio.run_coroutine_threadsafe(coro, loop)

def _batch_iterator(self, sizes, max_value):
"""Return an iterator that calculates intervals (start, end) of a
descend-order list, in which the sum of values in the range is the
maximum number not less than max_value. By "the sum of values",

here it means $$len(sizes[start:end]) * sizes[start]$$
"""
i = 0
while i < len(sizes):
current_sum = 0
start_index = i

while i < len(sizes) and current_sum + sizes[start_index] <= max_value:
current_sum += sizes[start_index]
i += 1

yield (start_index, i)
if i > start_index:
continue
else:
i += 1

def _get_long_text_ppl(self, session, input_ids, max_input_len):
    """Compute the mean cross-entropy loss of one long sequence.

    The sequence is longer than ``max_input_len``, so it is fed to the
    engine chunk by chunk on the same session (reusing the KV cache via
    ``session.update(step=...)``), and the per-chunk losses are averaged
    weighted by the number of target tokens in each chunk.
    """
    assert all(isinstance(tok, int) for tok in input_ids)
    seq_len = len(input_ids)
    assert seq_len > max_input_len
    logger.info(f'get long text ppl: seq_len {seq_len}')

    weighted_loss = 0.0
    total_targets = 0
    for offset in range(0, seq_len, max_input_len):
        chunk = input_ids[offset:offset + max_input_len]
        session.update(step=offset)
        # Targets are the chunk shifted one token to the left.
        targets = input_ids[offset + 1:offset + 1 + max_input_len]
        chunk_losses = self._get_ppl(sessions=[session],
                                     input_ids=[chunk],
                                     max_input_len=len(chunk),
                                     target_ids=[targets],
                                     sequence_start=(offset == 0),
                                     sequence_end=False)
        # One session per call, so chunk_losses holds a single mean loss;
        # re-weight it by its target count before accumulating.
        weighted_loss += sum(chunk_losses) * len(targets)
        total_targets += len(targets)
    return weighted_loss / total_targets

def _get_ppl(self,
             sessions: list[Session],
             input_ids: list[list[int]],
             max_input_len: int,
             target_ids=None,
             sequence_start: bool = True,
             sequence_end: bool = True):
    """Compute the mean token-level cross-entropy loss for each sequence.

    Args:
        sessions: one engine session per input sequence.
        input_ids: batch of token-id lists; their total length must not
            exceed ``max_input_len``.
        max_input_len: upper bound on the summed batch length.
        target_ids: optional per-sequence targets; when ``None``, targets
            are the inputs shifted left by one position.
        sequence_start: whether this call starts the sessions' sequences.
        sequence_end: whether this call ends the sessions' sequences.

    Returns:
        A list with one mean loss (sum of CE over non-padding targets
        divided by the target count) per input sequence.
    """
    assert (isinstance(input_ids, list) and all(isinstance(_, list) for _ in input_ids))
    assert target_ids is None or len(target_ids) == len(input_ids)
    assert len(sessions) == len(input_ids)

    lens = [len(_) for _ in input_ids]
    total_len = sum(lens)
    assert sum(lens) <= max_input_len

    logger.info(f'get_ppl: bs: {len(input_ids)}, lens: {lens}, '
                f'total_len: {total_len}')
    # Free cached GPU blocks before materialising [seq_len, vocab] logits.
    torch.cuda.empty_cache()

    logits = self._run(coro=self.async_engine.async_get_logits(
        input_ids=input_ids, sessions=sessions, sequence_start=sequence_start, sequence_end=sequence_end)).result()
    # -100 matches cross_entropy's default ignore_index.
    padding_token_id = -100
    if target_ids is None:
        # Shift inputs left by one; pad the last position so it is ignored.
        target_ids = [x[1:] + [padding_token_id] for x in input_ids]
    else:
        # Pad caller-provided targets that are one token short of the input.
        target_ids = [
            target_ids[i] + [padding_token_id] if len(target_ids[i]) < len(input_ids[i]) else target_ids[i]
            for i in range(len(input_ids))
        ]
    # NOTE(review): torch.Tensor(torch.LongTensor(...)) — verify the targets
    # stay integral here; cross_entropy's ignore_index path requires long
    # class-index targets, not float probabilities.
    target_ids = [torch.Tensor(torch.LongTensor(_target_ids)) for _target_ids in target_ids]

    result = []
    for _logits, _target_ids in zip(logits, target_ids):
        _logits = _logits.float()
        vocab_size = _logits.shape[-1]
        _target_ids = _target_ids.to(_logits.device)
        # Positions equal to the padding id carry no loss.
        target_mask = _target_ids != padding_token_id
        # compute cross entropy loss
        flat_logits = _logits.contiguous().view(-1, vocab_size)
        flat_target_ids = _target_ids.contiguous().view(-1)
        flat_loss_matrix = torch.nn.functional.cross_entropy(flat_logits,
                                                             flat_target_ids,
                                                             reduction='none',
                                                             ignore_index=padding_token_id)
        # Mean over real (non-padding) targets only.
        loss = flat_loss_matrix.sum()
        target_count = target_mask.sum()
        result.append(loss.item() / target_count.item())
    logger.info(f'ppl result: {result}')
    return result


class _EventLoopThread:
Expand Down
2 changes: 2 additions & 0 deletions lmdeploy/pytorch/engine/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ class InferOutput:
finish: bool = False
logits: torch.Tensor = None
logprobs: torch.Tensor = None
ppl_loss: float = None
ppl_count: int = None

# send cache blocks back for migration in Disaggregated LLM Serving
# when Prefill Engine is Done.
Expand Down
6 changes: 6 additions & 0 deletions lmdeploy/pytorch/engine/engine_instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,14 +190,20 @@ async def async_stream_infer(self,
# request might be cancelled before any output
token_ids = []
logits = None
ppl_loss = None
ppl_count = None
else:
token_ids = resp_data['token_ids'][output_offset:].tolist()
logits = resp_data.get('logits', None)
ppl_loss = resp_data.get('ppl_loss', None)
ppl_count = resp_data.get('ppl_count', None)
num_ids = len(token_ids) - output_offset
logger.debug(f'session[{session_id}] finish: num_out_ids={num_ids}.')
yield EngineOutput(resp.type,
token_ids,
logits=logits,
ppl_loss=ppl_loss,
ppl_count=ppl_count,
cache_block_ids=cache_block_ids,
req_metrics=req_metrics,
routed_experts=routed_experts,
Expand Down
14 changes: 14 additions & 0 deletions lmdeploy/pytorch/engine/engine_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,8 @@ def _send_resp(self, out: InferOutput):
resp_type,
data=dict(token_ids=out.token_ids,
logits=out.logits,
ppl_loss=out.ppl_loss,
ppl_count=out.ppl_count,
cache_block_ids=out.cache_block_ids,
req_metrics=out.req_metrics,
routed_experts=out.routed_experts,
Expand Down Expand Up @@ -236,6 +238,11 @@ def __get_logit(msg, logits: torch.Tensor, seq_length: list[int], idx: int):
seq = running[0]
seq.append_routed_experts(all_routed_experts)
seq.append_logits(logits)
if batched_outputs.ppl_losses is not None:
prev_loss = getattr(seq, '_ppl_loss_acc', 0.0)
prev_count = getattr(seq, '_ppl_count_acc', 0)
seq._ppl_loss_acc = prev_loss + batched_outputs.ppl_losses[0]
seq._ppl_count_acc = prev_count + batched_outputs.ppl_counts[0]
return dict()

new_token_timestamp = batched_outputs.new_token_timestamp
Expand Down Expand Up @@ -296,6 +303,13 @@ def __get_logit(msg, logits: torch.Tensor, seq_length: list[int], idx: int):
if msg.return_logits:
logit = __get_logit(msg, logits, seq_length, idx)
outputs[session_id].logits = logit
if batched_outputs.ppl_losses is not None:
ppl_loss = batched_outputs.ppl_losses[idx]
ppl_count = batched_outputs.ppl_counts[idx]
ppl_loss += getattr(msg, '_ppl_loss_acc', 0.0)
ppl_count += getattr(msg, '_ppl_count_acc', 0)
outputs[session_id].ppl_loss = ppl_loss
outputs[session_id].ppl_count = ppl_count
return outputs

async def _main_loop_try_send_next_inputs(self):
Expand Down
2 changes: 2 additions & 0 deletions lmdeploy/pytorch/engine/logits_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,8 @@ class SamplingInputs:
repetition_ngram_threshold: torch.Tensor | None = None
max_repetition_ngram_size: int = 0

compute_ppl: bool = False

def to_device(self, device: str, non_blocking: bool = False):
"""To device."""
out_dict = dict()
Expand Down
58 changes: 55 additions & 3 deletions lmdeploy/pytorch/engine/model_agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,40 @@
logger = get_logger('lmdeploy')


def _compute_ppl_from_logits(logits: torch.Tensor, input_ids: torch.Tensor,
seq_lengths: torch.Tensor) -> tuple[list[float], list[int]]:
"""Compute per-sequence cross-entropy loss from packed logits and
input_ids.

Args:
logits: [total_tokens, vocab_size] packed logits for all sequences.
input_ids: [total_tokens] packed input token ids.
seq_lengths: [batch_size] length of each sequence.

Returns:
(losses, counts): per-sequence summed CE loss and target token count.
"""
input_ids = input_ids.flatten()
losses = []
counts = []
offset = 0
for length in seq_lengths.tolist():
length = int(length)
if length <= 1:
losses.append(0.0)
counts.append(0)
offset += length
continue
seq_logits = logits[offset:offset + length - 1]
seq_targets = input_ids[offset + 1:offset + length]
loss = torch.nn.functional.cross_entropy(
seq_logits.float(), seq_targets, reduction='sum')
losses.append(loss.item())
counts.append(length - 1)
offset += length
return losses, counts


@dataclass
class SleepWakeupState:
to_sleep: asyncio.Event = field(default_factory=asyncio.Event)
Expand Down Expand Up @@ -82,6 +116,8 @@ class BatchedOutputs:
new_token_timestamp: int = 0
extra_outputs: ExtraOutputs | None = None
all_routed_experts: torch.Tensor | None = None
ppl_losses: list[float] | None = None
ppl_counts: list[int] | None = None

def to_cpu(self):
"""To cpu."""
Expand Down Expand Up @@ -677,7 +713,9 @@ async def _step_postprocess_with_output(self,
need_broadcast_next: bool,
return_logits: bool = False,
all_routed_experts: Any = None,
extra_inputs: ExtraInputs = None):
extra_inputs: ExtraInputs = None,
ppl_losses: list[float] = None,
ppl_counts: list[int] = None):
"""Step postprocess with output."""
rank = self.rank
logger.debug(f'<ForwardTask> rank[{rank}]: Sampling.')
Expand Down Expand Up @@ -720,7 +758,9 @@ async def _step_postprocess_with_output(self,
model_metas=model_metas,
logprobs=logprobs,
all_routed_experts=all_routed_experts,
extra_outputs=extra_outputs))
extra_outputs=extra_outputs,
ppl_losses=ppl_losses,
ppl_counts=ppl_counts))

return inputs, extra_inputs, stopping_criteria, extra_outputs, next_token_ids

Expand Down Expand Up @@ -834,14 +874,18 @@ def __update_inputs(
# swap caches
cache_swapping(self.cache_engine, swap_in_map=swap_in_map, swap_out_map=swap_out_map)

# PPL needs full logits for all positions, not just the last token
need_full_logits = (sampling_inputs.compute_ppl and not is_decoding)
forward_return_logits = return_logits or need_full_logits

# inference
logger.debug(f'<ForwardTask> rank[{rank}]: model forward. '
f'batch_size={inputs.seq_length.size(0)} '
f'num_tokens={inputs.input_ids.size(-1)} '
f'is_decoding={inputs.is_decoding}')
output = await self._async_model_forward(
inputs,
return_logits=return_logits,
return_logits=forward_return_logits,
)
# recovery is_decoding
inputs.is_decoding = is_decoding
Expand All @@ -856,6 +900,12 @@ def __update_inputs(
extra_inputs = self.agent_strategy.slice_extra_inputs(extra_inputs, inputs, output)
model_metas = output.get('model_metas')

ppl_losses = None
ppl_counts = None
if sampling_inputs.compute_ppl and not inputs.is_decoding:
ppl_losses, ppl_counts = _compute_ppl_from_logits(
logits, inputs.input_ids, seq_length)

if self.need_output:
logger.debug(f'<ForwardTask> rank[{rank}]: Sampling.')
# for router replay
Expand All @@ -881,6 +931,8 @@ def __update_inputs(
return_logits=return_logits,
all_routed_experts=all_routed_experts,
extra_inputs=extra_inputs,
ppl_losses=ppl_losses,
ppl_counts=ppl_counts,
)
else:
(
Expand Down
Loading
Loading