Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions lmdeploy/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ class GenerationConfig:
logits_processors: list[LogitsProcessor] | None = None
output_logits: Literal['all', 'generation'] = None
output_last_hidden_state: Literal['all', 'generation'] = None
output_ppl: bool = False
include_stop_str_in_output: bool = False

# for disaggregation
Expand Down Expand Up @@ -496,6 +497,8 @@ class Response:
logprobs: list[dict[int, float]] = None
logits: torch.Tensor = None
last_hidden_state: torch.Tensor = None
ppl_loss: float = None
ppl_count: int = None
index: int = 0
routed_experts: Any = None

Expand Down Expand Up @@ -632,6 +635,8 @@ class EngineOutput:
logprobs: list[dict[int, float]] = None
logits: torch.Tensor = None
last_hidden_state: torch.Tensor = None
ppl_loss: float = None
ppl_count: int = None
cache_block_ids: list[int] | None = None
req_metrics: RequestMetrics | None = None
routed_experts: torch.Tensor = None
Expand Down
144 changes: 10 additions & 134 deletions lmdeploy/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
from threading import Thread
from typing import TYPE_CHECKING

import torch
import tqdm
from typing_extensions import deprecated

Expand Down Expand Up @@ -262,8 +261,12 @@ def get_reward_score(self, input_ids: list) -> list[float]:
return scores

def get_ppl(self, input_ids: list[int] | list[list[int]]) -> list[float]:
"""Get perplexity scores given a list of input tokens that have to be
of the same length.
"""Get perplexity scores given a list of input tokens.

Cross-entropy is computed inline inside the engine (on GPU logits
chunks) so the full ``[seq_len, vocab_size]`` logits tensor is never
materialised on CPU, keeping memory usage constant regardless of
sequence length.

Args:
input_ids: the batch of input token ids.
Expand All @@ -276,40 +279,10 @@ def get_ppl(self, input_ids: list[int] | list[list[int]]) -> list[float]:
input_ids = [input_ids]
assert all(len(_) > 1 for _ in input_ids)

# TODO: a better way to determine `max_input_len`, at most allocate
# 2G mem for logits with shape [bs, max_input_len, vocab_size]
vocab_size = self.async_engine.hf_cfg.vocab_size
max_input_len = 2 * 1024**3 // (vocab_size * 4)
sizes = [len(_) for _ in input_ids]
result = []
sorted_index_values = sorted(list(enumerate(sizes)), key=lambda x: x[1], reverse=True)
sizes = [value for index, value in sorted_index_values]
indices = [index for index, value in sorted_index_values]
logger.info(f'sorted sizes: {sizes}')
logger.info(f'sorted indices: {indices}')
for (start, end) in self._batch_iterator(sizes, max_input_len):
logger.info(f'start: {start}, end: {end}')
if start == end:
_input_ids = input_ids[indices[start]]
session = self.session_mgr.get()
res = self._get_long_text_ppl(session, input_ids=_input_ids, max_input_len=max_input_len)
result.append(res)
self.session_mgr.remove(session)
else:
_input_ids = [input_ids[indices[i]] for i in range(start, end)]
sessions = [self.session_mgr.get() for _ in range(start, end)]
res = self._get_ppl(
sessions=sessions,
input_ids=_input_ids,
max_input_len=max_input_len,
)
result.extend(res)
for session in sessions:
self.session_mgr.remove(session)
output = list(range(len(result)))
for index, sorted_index in enumerate(indices):
output[sorted_index] = result[index]
return output
results = self._run(
coro=self.async_engine.async_get_ppl(input_ids=input_ids)
).result()
return [loss / count if count > 0 else 0.0 for loss, count in results]

def __call__(self,
prompts: list[str] | str | list[dict] | list[list[dict]],
Expand Down Expand Up @@ -429,103 +402,6 @@ async def _coro():
coro = _coro()
return asyncio.run_coroutine_threadsafe(coro, loop)

def _batch_iterator(self, sizes, max_value):
"""Return an iterator that calculates intervals (start, end) of a
descend-order list, in which the sum of values in the range is the
maximum number not less than max_value. By "the sum of values",

here it means $$len(sizes[start:end]) * sizes[start]$$
"""
i = 0
while i < len(sizes):
current_sum = 0
start_index = i

while i < len(sizes) and current_sum + sizes[start_index] <= max_value:
current_sum += sizes[start_index]
i += 1

yield (start_index, i)
if i > start_index:
continue
else:
i += 1

def _get_long_text_ppl(self, session, input_ids, max_input_len):
    """Compute the mean cross-entropy loss of one long sequence.

    The sequence is processed chunk by chunk (each at most
    ``max_input_len`` tokens) within a single session so the KV cache
    carries context across chunks; per-chunk mean losses are re-weighted
    by their target counts to produce the overall per-token mean.
    """
    assert all(isinstance(token, int) for token in input_ids)
    seq_len = len(input_ids)
    assert seq_len > max_input_len
    logger.info(f'get long text ppl: seq_len {seq_len}')

    weighted_loss_sum = 0.0
    total_targets = 0
    for start in range(0, seq_len, max_input_len):
        chunk = input_ids[start:start + max_input_len]
        session.update(step=start)
        # targets are the input ids shifted left by one position
        targets = input_ids[start + 1:start + 1 + max_input_len]
        chunk_losses = self._get_ppl(sessions=[session],
                                     input_ids=[chunk],
                                     max_input_len=len(chunk),
                                     target_ids=[targets],
                                     sequence_start=(start == 0),
                                     sequence_end=False)
        # _get_ppl returns per-token means; undo the mean with the
        # chunk's own target count before accumulating
        weighted_loss_sum += sum(loss * len(targets) for loss in chunk_losses)
        total_targets += len(targets)
    return weighted_loss_sum / total_targets

def _get_ppl(self,
             sessions: list[Session],
             input_ids: list[list[int]],
             max_input_len: int,
             target_ids=None,
             sequence_start: bool = True,
             sequence_end: bool = True):
    """Run one engine forward pass and return per-sequence mean CE loss.

    Args:
        sessions: one engine session per sequence in the batch.
        input_ids: batch of token-id lists; their combined length must
            not exceed ``max_input_len``.
        max_input_len: upper bound on the total number of input tokens.
        target_ids: optional explicit targets per sequence; when
            ``None``, targets default to the inputs shifted left by one.
        sequence_start: whether this call is the first chunk of each
            session (resets engine-side context).
        sequence_end: whether this call is the last chunk of each session.

    Returns:
        list[float]: mean cross-entropy loss per sequence, averaged over
        real (non-padding) target positions only.
    """
    assert (isinstance(input_ids, list) and all(isinstance(_, list) for _ in input_ids))
    assert target_ids is None or len(target_ids) == len(input_ids)
    assert len(sessions) == len(input_ids)

    lens = [len(_) for _ in input_ids]
    total_len = sum(lens)
    assert sum(lens) <= max_input_len

    logger.info(f'get_ppl: bs: {len(input_ids)}, lens: {lens}, '
                f'total_len: {total_len}')
    # free cached GPU blocks before materialising [total_len, vocab] logits
    torch.cuda.empty_cache()

    logits = self._run(coro=self.async_engine.async_get_logits(
        input_ids=input_ids, sessions=sessions, sequence_start=sequence_start, sequence_end=sequence_end)).result()
    # positions carrying this id are excluded from the loss
    padding_token_id = -100
    if target_ids is None:
        # default targets: inputs shifted left; the final position has
        # no next token, so it is padded out
        target_ids = [x[1:] + [padding_token_id] for x in input_ids]
    else:
        # pad short target lists so each lines up with its input chunk
        target_ids = [
            target_ids[i] + [padding_token_id] if len(target_ids[i]) < len(input_ids[i]) else target_ids[i]
            for i in range(len(input_ids))
        ]
    # NOTE(review): torch.Tensor(...) wrapping a LongTensor casts to the
    # default float dtype, yet cross_entropy below expects integer class
    # targets — confirm the targets really reach it as a LongTensor.
    target_ids = [torch.Tensor(torch.LongTensor(_target_ids)) for _target_ids in target_ids]

    result = []
    for _logits, _target_ids in zip(logits, target_ids):
        _logits = _logits.float()
        vocab_size = _logits.shape[-1]
        _target_ids = _target_ids.to(_logits.device)
        target_mask = _target_ids != padding_token_id
        # compute cross entropy loss
        flat_logits = _logits.contiguous().view(-1, vocab_size)
        flat_target_ids = _target_ids.contiguous().view(-1)
        flat_loss_matrix = torch.nn.functional.cross_entropy(flat_logits,
                                                             flat_target_ids,
                                                             reduction='none',
                                                             ignore_index=padding_token_id)
        # ignore_index zeroes padded positions, so summing and dividing
        # by the mask count yields the mean over real targets only
        loss = flat_loss_matrix.sum()
        target_count = target_mask.sum()
        result.append(loss.item() / target_count.item())
    logger.info(f'ppl result: {result}')
    return result


class _EventLoopThread:
Expand Down
2 changes: 2 additions & 0 deletions lmdeploy/pytorch/engine/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ class InferOutput:
finish: bool = False
logits: torch.Tensor = None
logprobs: torch.Tensor = None
ppl_loss: float = None
ppl_count: int = None

# send cache blocks back for migration in Disaggregated LLM Serving
# when Prefill Engine is Done.
Expand Down
6 changes: 6 additions & 0 deletions lmdeploy/pytorch/engine/engine_instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,14 +190,20 @@ async def async_stream_infer(self,
# request might be cancelled before any output
token_ids = []
logits = None
ppl_loss = None
ppl_count = None
else:
token_ids = resp_data['token_ids'][output_offset:].tolist()
logits = resp_data.get('logits', None)
ppl_loss = resp_data.get('ppl_loss', None)
ppl_count = resp_data.get('ppl_count', None)
num_ids = len(token_ids) - output_offset
logger.debug(f'session[{session_id}] finish: num_out_ids={num_ids}.')
yield EngineOutput(resp.type,
token_ids,
logits=logits,
ppl_loss=ppl_loss,
ppl_count=ppl_count,
cache_block_ids=cache_block_ids,
req_metrics=req_metrics,
routed_experts=routed_experts,
Expand Down
14 changes: 14 additions & 0 deletions lmdeploy/pytorch/engine/engine_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,8 @@ def _send_resp(self, out: InferOutput):
resp_type,
data=dict(token_ids=out.token_ids,
logits=out.logits,
ppl_loss=out.ppl_loss,
ppl_count=out.ppl_count,
cache_block_ids=out.cache_block_ids,
req_metrics=out.req_metrics,
routed_experts=out.routed_experts,
Expand Down Expand Up @@ -236,6 +238,11 @@ def __get_logit(msg, logits: torch.Tensor, seq_length: list[int], idx: int):
seq = running[0]
seq.append_routed_experts(all_routed_experts)
seq.append_logits(logits)
if batched_outputs.ppl_losses is not None:
prev_loss = getattr(seq, '_ppl_loss_acc', 0.0)
prev_count = getattr(seq, '_ppl_count_acc', 0)
seq._ppl_loss_acc = prev_loss + batched_outputs.ppl_losses[0]
seq._ppl_count_acc = prev_count + batched_outputs.ppl_counts[0]
return dict()

new_token_timestamp = batched_outputs.new_token_timestamp
Expand Down Expand Up @@ -296,6 +303,13 @@ def __get_logit(msg, logits: torch.Tensor, seq_length: list[int], idx: int):
if msg.return_logits:
logit = __get_logit(msg, logits, seq_length, idx)
outputs[session_id].logits = logit
if batched_outputs.ppl_losses is not None:
ppl_loss = batched_outputs.ppl_losses[idx]
ppl_count = batched_outputs.ppl_counts[idx]
ppl_loss += getattr(msg, '_ppl_loss_acc', 0.0)
ppl_count += getattr(msg, '_ppl_count_acc', 0)
outputs[session_id].ppl_loss = ppl_loss
outputs[session_id].ppl_count = ppl_count
return outputs

async def _main_loop_try_send_next_inputs(self):
Expand Down
2 changes: 2 additions & 0 deletions lmdeploy/pytorch/engine/logits_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,8 @@ class SamplingInputs:
repetition_ngram_threshold: torch.Tensor | None = None
max_repetition_ngram_size: int = 0

compute_ppl: bool = False

def to_device(self, device: str, non_blocking: bool = False):
"""To device."""
out_dict = dict()
Expand Down
59 changes: 56 additions & 3 deletions lmdeploy/pytorch/engine/model_agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,41 @@
logger = get_logger('lmdeploy')


def _compute_ppl_from_logits(logits: torch.Tensor, input_ids: torch.Tensor,
seq_lengths: torch.Tensor) -> tuple[list[float], list[int]]:
"""Compute per-sequence cross-entropy loss from packed logits and
input_ids.

Args:
logits: [total_tokens, vocab_size] packed logits for all sequences.
input_ids: [total_tokens] packed input token ids.
seq_lengths: [batch_size] length of each sequence.

Returns:
(losses, counts): per-sequence summed CE loss and target token count.
"""
input_ids = input_ids.flatten()
losses = []
counts = []
offset = 0
for length in seq_lengths.tolist():
length = int(length)
if length <= 1:
losses.append(0.0)
counts.append(0)
offset += length
continue
seq_logits = logits[offset:offset + length - 1]
seq_targets = input_ids[offset + 1:offset + length]
print(f'seq_logits: {seq_logits.shape}, seq_targets: {seq_targets.shape}')
Copy link

Copilot AI Mar 25, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is a stray print(...) debug statement in _compute_ppl_from_logits that will spam stdout during inference and can severely impact throughput when called frequently. Please remove it or replace it with a guarded logger.debug if you still need shape diagnostics.

Suggested change
print(f'seq_logits: {seq_logits.shape}, seq_targets: {seq_targets.shape}')
logger.debug(
'seq_logits: %s, seq_targets: %s',
seq_logits.shape,
seq_targets.shape,
)

Copilot uses AI. Check for mistakes.
loss = torch.nn.functional.cross_entropy(
seq_logits.float(), seq_targets, reduction='sum')
losses.append(loss.item())
counts.append(length - 1)
offset += length
return losses, counts


@dataclass
class SleepWakeupState:
to_sleep: asyncio.Event = field(default_factory=asyncio.Event)
Expand Down Expand Up @@ -82,6 +117,8 @@ class BatchedOutputs:
new_token_timestamp: int = 0
extra_outputs: ExtraOutputs | None = None
all_routed_experts: torch.Tensor | None = None
ppl_losses: list[float] | None = None
ppl_counts: list[int] | None = None

def to_cpu(self):
"""To cpu."""
Expand Down Expand Up @@ -677,7 +714,9 @@ async def _step_postprocess_with_output(self,
need_broadcast_next: bool,
return_logits: bool = False,
all_routed_experts: Any = None,
extra_inputs: ExtraInputs = None):
extra_inputs: ExtraInputs = None,
ppl_losses: list[float] = None,
ppl_counts: list[int] = None):
"""Step postprocess with output."""
rank = self.rank
logger.debug(f'<ForwardTask> rank[{rank}]: Sampling.')
Expand Down Expand Up @@ -720,7 +759,9 @@ async def _step_postprocess_with_output(self,
model_metas=model_metas,
logprobs=logprobs,
all_routed_experts=all_routed_experts,
extra_outputs=extra_outputs))
extra_outputs=extra_outputs,
ppl_losses=ppl_losses,
ppl_counts=ppl_counts))

return inputs, extra_inputs, stopping_criteria, extra_outputs, next_token_ids

Expand Down Expand Up @@ -834,14 +875,18 @@ def __update_inputs(
# swap caches
cache_swapping(self.cache_engine, swap_in_map=swap_in_map, swap_out_map=swap_out_map)

# PPL needs full logits for all positions, not just the last token
need_full_logits = (sampling_inputs.compute_ppl and not is_decoding)
forward_return_logits = return_logits or need_full_logits

# inference
logger.debug(f'<ForwardTask> rank[{rank}]: model forward. '
f'batch_size={inputs.seq_length.size(0)} '
f'num_tokens={inputs.input_ids.size(-1)} '
f'is_decoding={inputs.is_decoding}')
output = await self._async_model_forward(
inputs,
return_logits=return_logits,
return_logits=forward_return_logits,
)
# recovery is_decoding
inputs.is_decoding = is_decoding
Expand All @@ -856,6 +901,12 @@ def __update_inputs(
extra_inputs = self.agent_strategy.slice_extra_inputs(extra_inputs, inputs, output)
model_metas = output.get('model_metas')

ppl_losses = None
ppl_counts = None
if sampling_inputs.compute_ppl and not inputs.is_decoding:
ppl_losses, ppl_counts = _compute_ppl_from_logits(
logits, inputs.input_ids, seq_length)

if self.need_output:
logger.debug(f'<ForwardTask> rank[{rank}]: Sampling.')
# for router replay
Expand All @@ -881,6 +932,8 @@ def __update_inputs(
return_logits=return_logits,
all_routed_experts=all_routed_experts,
extra_inputs=extra_inputs,
ppl_losses=ppl_losses,
ppl_counts=ppl_counts,
)
else:
(
Expand Down
Loading
Loading