diff --git a/.github/workflows/unit_test.yml b/.github/workflows/unit_test.yml index cc19ae1be1..c17a49b90a 100644 --- a/.github/workflows/unit_test.yml +++ b/.github/workflows/unit_test.yml @@ -55,7 +55,7 @@ jobs: lmdeploy check_env - name: Test lmdeploy python UT run: | - coverage run --branch --source lmdeploy -m pytest -rsE tests + coverage run --branch --source lmdeploy -m pytest -vv tests coverage xml coverage report -m - name: Clear workfile diff --git a/lmdeploy/messages.py b/lmdeploy/messages.py index d6cd1a3329..ee8b8fa6e2 100644 --- a/lmdeploy/messages.py +++ b/lmdeploy/messages.py @@ -122,6 +122,7 @@ class GenerationConfig: logits_processors: list[LogitsProcessor] | None = None output_logits: Literal['all', 'generation'] = None output_last_hidden_state: Literal['all', 'generation'] = None + output_ppl: bool = False include_stop_str_in_output: bool = False # for disaggregation @@ -496,6 +497,8 @@ class Response: logprobs: list[dict[int, float]] = None logits: torch.Tensor = None last_hidden_state: torch.Tensor = None + ppl_loss: float = None + ppl_count: int = None index: int = 0 routed_experts: Any = None @@ -632,6 +635,8 @@ class EngineOutput: logprobs: list[dict[int, float]] = None logits: torch.Tensor = None last_hidden_state: torch.Tensor = None + ppl_loss: float = None + ppl_count: int = None cache_block_ids: list[int] | None = None req_metrics: RequestMetrics | None = None routed_experts: torch.Tensor = None diff --git a/lmdeploy/pipeline.py b/lmdeploy/pipeline.py index ca4c42bba0..b572756aff 100644 --- a/lmdeploy/pipeline.py +++ b/lmdeploy/pipeline.py @@ -12,7 +12,6 @@ from threading import Thread from typing import TYPE_CHECKING -import torch import tqdm from typing_extensions import deprecated @@ -262,8 +261,12 @@ def get_reward_score(self, input_ids: list) -> list[float]: return scores def get_ppl(self, input_ids: list[int] | list[list[int]]) -> list[float]: - """Get perplexity scores given a list of input tokens that have to be - of the same length. + """Get perplexity scores given a list of input tokens. + + Cross-entropy is computed inline inside the engine (on GPU logits + chunks) so the full ``[seq_len, vocab_size]`` logits tensor is never + materialised on CPU, keeping memory usage constant regardless of + sequence length. Args: input_ids: the batch of input token ids. 
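[Reviewer note] A minimal caller-side sketch of the reworked `get_ppl` path described in the docstring above. The model id, the prompt, and the `pipe.tokenizer` attribute are illustrative assumptions rather than part of this diff; `get_ppl` still returns the mean token cross-entropy per sequence, so perplexity is its exponential:

```python
# Hedged usage sketch; model id and tokenizer access are placeholders only.
import math

from lmdeploy import pipeline

pipe = pipeline('internlm/internlm2_5-7b-chat')  # hypothetical model id
input_ids = pipe.tokenizer.encode('Perplexity is now computed inside the engine.')
losses = pipe.get_ppl([input_ids])           # per-sequence mean cross-entropy
ppls = [math.exp(loss) for loss in losses]   # perplexity = exp(mean CE)
```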
@@ -276,40 +279,10 @@ def get_ppl(self, input_ids: list[int] | list[list[int]]) -> list[float]: input_ids = [input_ids] assert all(len(_) > 1 for _ in input_ids) - # TODO: a better way to determine `max_input_len`, at most allocate - # 2G mem for logits with shape [bs, max_input_len, vocab_size] - vocab_size = self.async_engine.hf_cfg.vocab_size - max_input_len = 2 * 1024**3 // (vocab_size * 4) - sizes = [len(_) for _ in input_ids] - result = [] - sorted_index_values = sorted(list(enumerate(sizes)), key=lambda x: x[1], reverse=True) - sizes = [value for index, value in sorted_index_values] - indices = [index for index, value in sorted_index_values] - logger.info(f'sorted sizes: {sizes}') - logger.info(f'sorted indices: {indices}') - for (start, end) in self._batch_iterator(sizes, max_input_len): - logger.info(f'start: {start}, end: {end}') - if start == end: - _input_ids = input_ids[indices[start]] - session = self.session_mgr.get() - res = self._get_long_text_ppl(session, input_ids=_input_ids, max_input_len=max_input_len) - result.append(res) - self.session_mgr.remove(session) - else: - _input_ids = [input_ids[indices[i]] for i in range(start, end)] - sessions = [self.session_mgr.get() for _ in range(start, end)] - res = self._get_ppl( - sessions=sessions, - input_ids=_input_ids, - max_input_len=max_input_len, - ) - result.extend(res) - for session in sessions: - self.session_mgr.remove(session) - output = list(range(len(result))) - for index, sorted_index in enumerate(indices): - output[sorted_index] = result[index] - return output + results = self._run( + coro=self.async_engine.async_get_ppl(input_ids=input_ids) + ).result() + return [loss / count if count > 0 else 0.0 for loss, count in results] def __call__(self, prompts: list[str] | str | list[dict] | list[list[dict]], @@ -429,103 +402,6 @@ async def _coro(): coro = _coro() return asyncio.run_coroutine_threadsafe(coro, loop) - def _batch_iterator(self, sizes, max_value): - """Return an iterator that calculates intervals (start, end) of a - descend-order list, in which the sum of values in the range is the - maximum number not less than max_value. 
By "the sum of values", - - here it means $$len(sizes[start:end]) * sizes[start]$$ - """ - i = 0 - while i < len(sizes): - current_sum = 0 - start_index = i - - while i < len(sizes) and current_sum + sizes[start_index] <= max_value: - current_sum += sizes[start_index] - i += 1 - - yield (start_index, i) - if i > start_index: - continue - else: - i += 1 - - def _get_long_text_ppl(self, session, input_ids, max_input_len): - assert all(isinstance(_, int) for _ in input_ids) - seq_len = len(input_ids) - assert seq_len > max_input_len - logger.info(f'get long text ppl: seq_len {seq_len}') - - losses = [] - target_counts = [] - for i in range(0, seq_len, max_input_len): - token_ids = input_ids[i:i + max_input_len] - session.update(step=i) - # shift token_ids by 1 to the left - target_ids = input_ids[i + 1:i + 1 + max_input_len] - loss = self._get_ppl(sessions=[session], - input_ids=[token_ids], - max_input_len=len(token_ids), - target_ids=[target_ids], - sequence_start=(i == 0), - sequence_end=False) - losses.extend(loss) - target_counts.append(len(target_ids)) - losses = [loss * target_count for loss, target_count in zip(losses, target_counts)] - loss_sum = sum(losses) - target_count = sum(target_counts) - return loss_sum / target_count - - def _get_ppl(self, - sessions: list[Session], - input_ids: list[list[int]], - max_input_len: int, - target_ids=None, - sequence_start: bool = True, - sequence_end: bool = True): - assert (isinstance(input_ids, list) and all(isinstance(_, list) for _ in input_ids)) - assert target_ids is None or len(target_ids) == len(input_ids) - assert len(sessions) == len(input_ids) - - lens = [len(_) for _ in input_ids] - total_len = sum(lens) - assert sum(lens) <= max_input_len - - logger.info(f'get_ppl: bs: {len(input_ids)}, lens: {lens}, ' - f'total_len: {total_len}') - torch.cuda.empty_cache() - - logits = self._run(coro=self.async_engine.async_get_logits( - input_ids=input_ids, sessions=sessions, sequence_start=sequence_start, sequence_end=sequence_end)).result() - padding_token_id = -100 - if target_ids is None: - target_ids = [x[1:] + [padding_token_id] for x in input_ids] - else: - target_ids = [ - target_ids[i] + [padding_token_id] if len(target_ids[i]) < len(input_ids[i]) else target_ids[i] - for i in range(len(input_ids)) - ] - target_ids = [torch.Tensor(torch.LongTensor(_target_ids)) for _target_ids in target_ids] - - result = [] - for _logits, _target_ids in zip(logits, target_ids): - _logits = _logits.float() - vocab_size = _logits.shape[-1] - _target_ids = _target_ids.to(_logits.device) - target_mask = _target_ids != padding_token_id - # compute cross entropy loss - flat_logits = _logits.contiguous().view(-1, vocab_size) - flat_target_ids = _target_ids.contiguous().view(-1) - flat_loss_matrix = torch.nn.functional.cross_entropy(flat_logits, - flat_target_ids, - reduction='none', - ignore_index=padding_token_id) - loss = flat_loss_matrix.sum() - target_count = target_mask.sum() - result.append(loss.item() / target_count.item()) - logger.info(f'ppl result: {result}') - return result class _EventLoopThread: diff --git a/lmdeploy/pytorch/engine/engine.py b/lmdeploy/pytorch/engine/engine.py index 27848de026..b1d2ba8bfe 100644 --- a/lmdeploy/pytorch/engine/engine.py +++ b/lmdeploy/pytorch/engine/engine.py @@ -45,6 +45,8 @@ class InferOutput: finish: bool = False logits: torch.Tensor = None logprobs: torch.Tensor = None + ppl_loss: float = None + ppl_count: int = None # send cache blocks back for migration in Disaggregated LLM Serving # when Prefill Engine is 
Done. diff --git a/lmdeploy/pytorch/engine/engine_instance.py b/lmdeploy/pytorch/engine/engine_instance.py index c2bbc03420..428c0f8ae3 100644 --- a/lmdeploy/pytorch/engine/engine_instance.py +++ b/lmdeploy/pytorch/engine/engine_instance.py @@ -190,14 +190,20 @@ async def async_stream_infer(self, # request might be cancelled before any output token_ids = [] logits = None + ppl_loss = None + ppl_count = None else: token_ids = resp_data['token_ids'][output_offset:].tolist() logits = resp_data.get('logits', None) + ppl_loss = resp_data.get('ppl_loss', None) + ppl_count = resp_data.get('ppl_count', None) num_ids = len(token_ids) - output_offset logger.debug(f'session[{session_id}] finish: num_out_ids={num_ids}.') yield EngineOutput(resp.type, token_ids, logits=logits, + ppl_loss=ppl_loss, + ppl_count=ppl_count, cache_block_ids=cache_block_ids, req_metrics=req_metrics, routed_experts=routed_experts, diff --git a/lmdeploy/pytorch/engine/engine_loop.py b/lmdeploy/pytorch/engine/engine_loop.py index 954f6cc4c7..dcdba698cb 100644 --- a/lmdeploy/pytorch/engine/engine_loop.py +++ b/lmdeploy/pytorch/engine/engine_loop.py @@ -160,6 +160,8 @@ def _send_resp(self, out: InferOutput): resp_type, data=dict(token_ids=out.token_ids, logits=out.logits, + ppl_loss=out.ppl_loss, + ppl_count=out.ppl_count, cache_block_ids=out.cache_block_ids, req_metrics=out.req_metrics, routed_experts=out.routed_experts, @@ -236,6 +238,11 @@ def __get_logit(msg, logits: torch.Tensor, seq_length: list[int], idx: int): seq = running[0] seq.append_routed_experts(all_routed_experts) seq.append_logits(logits) + if batched_outputs.ppl_losses is not None: + prev_loss = getattr(seq, '_ppl_loss_acc', 0.0) + prev_count = getattr(seq, '_ppl_count_acc', 0) + seq._ppl_loss_acc = prev_loss + batched_outputs.ppl_losses[0] + seq._ppl_count_acc = prev_count + batched_outputs.ppl_counts[0] return dict() new_token_timestamp = batched_outputs.new_token_timestamp @@ -296,6 +303,13 @@ def __get_logit(msg, logits: torch.Tensor, seq_length: list[int], idx: int): if msg.return_logits: logit = __get_logit(msg, logits, seq_length, idx) outputs[session_id].logits = logit + if batched_outputs.ppl_losses is not None: + ppl_loss = batched_outputs.ppl_losses[idx] + ppl_count = batched_outputs.ppl_counts[idx] + ppl_loss += getattr(msg, '_ppl_loss_acc', 0.0) + ppl_count += getattr(msg, '_ppl_count_acc', 0) + outputs[session_id].ppl_loss = ppl_loss + outputs[session_id].ppl_count = ppl_count return outputs async def _main_loop_try_send_next_inputs(self): diff --git a/lmdeploy/pytorch/engine/logits_process.py b/lmdeploy/pytorch/engine/logits_process.py index f6f290fc29..8de03f49b2 100644 --- a/lmdeploy/pytorch/engine/logits_process.py +++ b/lmdeploy/pytorch/engine/logits_process.py @@ -271,6 +271,8 @@ class SamplingInputs: repetition_ngram_threshold: torch.Tensor | None = None max_repetition_ngram_size: int = 0 + compute_ppl: bool = False + def to_device(self, device: str, non_blocking: bool = False): """To device.""" out_dict = dict() diff --git a/lmdeploy/pytorch/engine/model_agent/agent.py b/lmdeploy/pytorch/engine/model_agent/agent.py index 6db690ab7f..959280e56c 100644 --- a/lmdeploy/pytorch/engine/model_agent/agent.py +++ b/lmdeploy/pytorch/engine/model_agent/agent.py @@ -38,6 +38,40 @@ logger = get_logger('lmdeploy') +def _compute_ppl_from_logits(logits: torch.Tensor, input_ids: torch.Tensor, + seq_lengths: torch.Tensor) -> tuple[list[float], list[int]]: + """Compute per-sequence cross-entropy loss from packed logits and + input_ids. 
+ + Args: + logits: [total_tokens, vocab_size] packed logits for all sequences. + input_ids: [total_tokens] packed input token ids. + seq_lengths: [batch_size] length of each sequence. + + Returns: + (losses, counts): per-sequence summed CE loss and target token count. + """ + input_ids = input_ids.flatten() + losses = [] + counts = [] + offset = 0 + for length in seq_lengths.tolist(): + length = int(length) + if length <= 1: + losses.append(0.0) + counts.append(0) + offset += length + continue + seq_logits = logits[offset:offset + length - 1] + seq_targets = input_ids[offset + 1:offset + length] + loss = torch.nn.functional.cross_entropy( + seq_logits.float(), seq_targets, reduction='sum') + losses.append(loss.item()) + counts.append(length - 1) + offset += length + return losses, counts + + @dataclass class SleepWakeupState: to_sleep: asyncio.Event = field(default_factory=asyncio.Event) @@ -82,6 +116,8 @@ class BatchedOutputs: new_token_timestamp: int = 0 extra_outputs: ExtraOutputs | None = None all_routed_experts: torch.Tensor | None = None + ppl_losses: list[float] | None = None + ppl_counts: list[int] | None = None def to_cpu(self): """To cpu.""" @@ -677,7 +713,9 @@ async def _step_postprocess_with_output(self, need_broadcast_next: bool, return_logits: bool = False, all_routed_experts: Any = None, - extra_inputs: ExtraInputs = None): + extra_inputs: ExtraInputs = None, + ppl_losses: list[float] = None, + ppl_counts: list[int] = None): """Step postprocess with output.""" rank = self.rank logger.debug(f' rank[{rank}]: Sampling.') @@ -720,7 +758,9 @@ async def _step_postprocess_with_output(self, model_metas=model_metas, logprobs=logprobs, all_routed_experts=all_routed_experts, - extra_outputs=extra_outputs)) + extra_outputs=extra_outputs, + ppl_losses=ppl_losses, + ppl_counts=ppl_counts)) return inputs, extra_inputs, stopping_criteria, extra_outputs, next_token_ids @@ -834,6 +874,10 @@ def __update_inputs( # swap caches cache_swapping(self.cache_engine, swap_in_map=swap_in_map, swap_out_map=swap_out_map) + # PPL needs full logits for all positions, not just the last token + need_full_logits = (sampling_inputs.compute_ppl and not is_decoding) + forward_return_logits = return_logits or need_full_logits + # inference logger.debug(f' rank[{rank}]: model forward. 
' f'batch_size={inputs.seq_length.size(0)} ' @@ -841,7 +885,7 @@ def __update_inputs( f'is_decoding={inputs.is_decoding}') output = await self._async_model_forward( inputs, - return_logits=return_logits, + return_logits=forward_return_logits, ) # recovery is_decoding inputs.is_decoding = is_decoding @@ -856,6 +900,12 @@ def __update_inputs( extra_inputs = self.agent_strategy.slice_extra_inputs(extra_inputs, inputs, output) model_metas = output.get('model_metas') + ppl_losses = None + ppl_counts = None + if sampling_inputs.compute_ppl and not inputs.is_decoding: + ppl_losses, ppl_counts = _compute_ppl_from_logits( + logits, inputs.input_ids, seq_length) + if self.need_output: logger.debug(f' rank[{rank}]: Sampling.') # for router replay @@ -881,6 +931,8 @@ def __update_inputs( return_logits=return_logits, all_routed_experts=all_routed_experts, extra_inputs=extra_inputs, + ppl_losses=ppl_losses, + ppl_counts=ppl_counts, ) else: ( diff --git a/lmdeploy/pytorch/messages.py b/lmdeploy/pytorch/messages.py index 218cfb5113..03f1f939e0 100644 --- a/lmdeploy/pytorch/messages.py +++ b/lmdeploy/pytorch/messages.py @@ -60,6 +60,7 @@ class SamplingParam: logits_processors: None | list[LogitsProcessor] = None out_logits: bool = False out_last_hidden_states: bool = False + compute_ppl: bool = False num_logprobs: int = -1 return_routed_experts: bool = False @@ -146,6 +147,7 @@ def from_gen_config(cls, gen_config: GenerationConfig): min_new_tokens=min_new_tokens, logits_processors=gen_config.logits_processors, out_logits=(output_logits is not None), + compute_ppl=gen_config.output_ppl, num_logprobs=logprobs, return_routed_experts=gen_config.return_routed_experts, repetition_ngram_size=gen_config.repetition_ngram_size, diff --git a/lmdeploy/pytorch/strategies/ar/sampling.py b/lmdeploy/pytorch/strategies/ar/sampling.py index 8a8c45d032..b178c41d8a 100644 --- a/lmdeploy/pytorch/strategies/ar/sampling.py +++ b/lmdeploy/pytorch/strategies/ar/sampling.py @@ -196,6 +196,8 @@ def __get_bad_words(bad_words): if repetition_ngram_same_n: repetition_ngram_sizes = None + compute_ppl = any(seq.sampling_param.compute_ppl for seq in seqs) + sampling_input = SamplingInputs( temperature=temperature, bad_words=bad_words, @@ -219,6 +221,7 @@ def __get_bad_words(bad_words): repetition_ngram_size=repetition_ngram_sizes, repetition_ngram_threshold=repetition_ngram_thresholds, max_repetition_ngram_size=max_repetition_ngram_size, + compute_ppl=compute_ppl, ) pad_token_id = self.pad_token_id diff --git a/lmdeploy/serve/core/async_engine.py b/lmdeploy/serve/core/async_engine.py index 114775fc69..86cddcd253 100644 --- a/lmdeploy/serve/core/async_engine.py +++ b/lmdeploy/serve/core/async_engine.py @@ -51,6 +51,8 @@ class GenOut: logprobs: list[dict[int, float]] | None = None logits: Any = None last_hidden_state: Any = None + ppl_loss: float = None + ppl_count: int = None cache_block_ids: list[int] | None = None # for disaggregation routed_experts: Any = None # for RL router replay @@ -68,6 +70,8 @@ def to_response(self, index: int = 0) -> Response: logprobs=self.logprobs, last_hidden_state=self.last_hidden_state, logits=self.logits, + ppl_loss=self.ppl_loss, + ppl_count=self.ppl_count, routed_experts=self.routed_experts, index=index) @@ -490,6 +494,8 @@ def is_error(status): gen_len, finish_reason, token_ids=res, + ppl_loss=outputs.ppl_loss, + ppl_count=outputs.ppl_count, routed_experts=outputs.routed_experts, cache_block_ids=outputs.cache_block_ids) if outputs.logprobs is not None: @@ -541,6 +547,8 @@ def is_error(status): 
logprobs=logprobs, logits=logits, last_hidden_state=last_hidden_state, + ppl_loss=outputs.ppl_loss, + ppl_count=outputs.ppl_count, routed_experts=routed_experts, cache_block_ids=outputs.cache_block_ids) # Note: We remove the session step update here. Let the caller(e.g., pipeline.chat) take care of it. @@ -673,3 +681,49 @@ async def _proc(session, i): for session in sessions: self.session_mgr.remove(session) return logits + + async def async_get_ppl(self, + input_ids, + sessions: list['Session'] | None = None, + sequence_start: bool = True, + sequence_end: bool = True) -> list[tuple[float, int]]: + """Compute per-sequence perplexity (loss, count) pairs. + + The engine computes cross-entropy inline and only returns scalar loss/count, avoiding materialisation of the + full logits tensor. + """ + assert input_ids and all(isinstance(_, list) for _ in input_ids) + assert sessions is None or (len(sessions) == len(input_ids)) + + results: list[tuple[float, int]] = [None] * len(input_ids) + + async def _proc(session, i): + async with session.request_handle() as handle: + max_new_tokens = 1 if self.backend == 'turbomind' else 0 + gen_config = GenerationConfig( + max_new_tokens=max_new_tokens, output_ppl=True, top_k=1) + async with self.safe_run(handle, + session=session, + input_ids=input_ids[i], + gen_config=gen_config, + stream_output=False, + sequence_start=sequence_start, + sequence_end=sequence_end, + step=session.step) as gen: + async for outputs in gen: + pass + results[i] = (outputs.ppl_loss, outputs.ppl_count) + + create_sessions = False + if sessions is None: + create_sessions = True + sessions = [self.session_mgr.get() for _ in range(len(input_ids))] + tasks = [_proc(session, i) for i, session in enumerate(sessions)] + await asyncio.gather(*tasks) + if sequence_end and self.backend == 'pytorch': + for session in sessions: + await session.async_close() + if sequence_end and create_sessions: + for session in sessions: + self.session_mgr.remove(session) + return results diff --git a/lmdeploy/turbomind/turbomind.py b/lmdeploy/turbomind/turbomind.py index f95b2b93ca..be94887023 100644 --- a/lmdeploy/turbomind/turbomind.py +++ b/lmdeploy/turbomind/turbomind.py @@ -735,8 +735,26 @@ async def async_stream_infer(self, sem = StreamingSemaphore() signal_cb = partial(self.async_signal_cb, sem) + logits_cb = None + ppl_state = None + if gen_config.output_ppl: + ppl_state = [0.0, 0] # [accumulated_loss, accumulated_count] + + def logits_cb(logits_chunk, vocab_size, begin, count): + targets = torch.tensor( + input_ids[begin + 1:begin + count + 1], + device=logits_chunk.device) + valid = min(count, len(targets)) + if valid > 0: + loss = torch.nn.functional.cross_entropy( + logits_chunk[:valid].float(), targets[:valid], + reduction='sum') + ppl_state[0] += loss.item() + ppl_state[1] += valid + outputs, shared_state, metrics = self.model_inst.forward(inputs, session, gen_cfg, stream_output, - self.tm_model.engine_config.enable_metrics, signal_cb) + self.tm_model.engine_config.enable_metrics, signal_cb, + logits_cb=logits_cb) outputs = _tm_dict_to_torch_dict(outputs) @@ -774,6 +792,10 @@ async def async_stream_infer(self, for f in extra_fs: f(output, seq_len) + if finish and ppl_state is not None: + output.ppl_loss = ppl_state[0] + output.ppl_count = ppl_state[1] + prev_len = seq_len yield output @@ -828,5 +850,6 @@ def _get_generation_config(self, cfg: GenerationConfig): c.output_logprobs = cfg.logprobs if cfg.random_seed is not None: c.random_seed = cfg.random_seed - # print (c) + if cfg.output_ppl: + 
c.compute_ppl = True return c diff --git a/src/turbomind/engine/model_request.cc b/src/turbomind/engine/model_request.cc index 3380d5047c..4d9ba94479 100644 --- a/src/turbomind/engine/model_request.cc +++ b/src/turbomind/engine/model_request.cc @@ -82,7 +82,7 @@ auto ModelRequest::Forward(InputParam param, std::function<void()> cb) -> Output add(outputs_, "output_ids", data_type_v<int>, kCPU, max_seq_len); add(outputs_, "sequence_length", data_type_v<int>, kCPU, 1); - if (param.gen_cfg.output_logits) { + if (param.gen_cfg.output_logits && !param.gen_cfg.compute_ppl) { const int len = param.gen_cfg.output_logits == GenerationConfig::kAll ? max_in_out_len : max_out_len; add(outputs_, "logits", data_type_, kCPU, len, vocab_size_); } @@ -124,6 +124,7 @@ auto ModelRequest::Forward(InputParam param, std::function<void()> cb) -> Output r->gen_cfg = param.gen_cfg; r->stream_output = param.stream_output; r->forward_cb = std::move(cb); + r->logits_cb = std::move(param.logits_cb); r->state = state; r->metrics = metrics; diff --git a/src/turbomind/engine/model_request.h b/src/turbomind/engine/model_request.h index 0944429947..c64edc9ef4 100644 --- a/src/turbomind/engine/model_request.h +++ b/src/turbomind/engine/model_request.h @@ -33,6 +33,8 @@ class ModelRequest { bool stream_output; bool enable_metrics; + + std::function<void(void*, int, int, int, DataType)> logits_cb; }; struct OutputParam { diff --git a/src/turbomind/engine/request.h b/src/turbomind/engine/request.h index 1c1f226412..fe07929da3 100644 --- a/src/turbomind/engine/request.h +++ b/src/turbomind/engine/request.h @@ -48,6 +48,8 @@ struct GenerationConfig { }; int output_last_hidden_state = 0; int output_logits = 0; + + bool compute_ppl = false; }; std::ostream& operator<<(std::ostream& os, const GenerationConfig& c); @@ -129,6 +131,9 @@ struct Request { std::shared_ptr grammar; std::shared_ptr matcher; + + // PPL callback: invoked with (logits_data_ptr, vocab_size, begin_offset, count, dtype) + std::function<void(void*, int, int, int, DataType)> logits_cb; }; void UpdateState(Request& r, int status, int seq_len); @@ -197,6 +202,7 @@ void serdes(Archive& ar, GenerationConfig& g) ar & g.output_logprobs; ar & g.output_last_hidden_state; ar & g.output_logits; + ar & g.compute_ppl; // clang-format on } diff --git a/src/turbomind/models/output_processor.cc b/src/turbomind/models/output_processor.cc index 92f943de40..190c1711a5 100644 --- a/src/turbomind/models/output_processor.cc +++ b/src/turbomind/models/output_processor.cc @@ -74,8 +74,17 @@ struct OutputProcessor::Impl { auto& c = *rc[i]; auto& r = *c.req; auto& g = r.gen_cfg; - if (g.output_logits) { - c.output_logits = g.output_logits == kAll ?
Interval{c.step0} : Interval{c.prompt_len - 1}; + if (g.output_logits || g.compute_ppl) { + if (g.compute_ppl && !g.output_logits) { + // PPL only needs prompt logits, not decode logits + c.output_logits = Interval{c.step0, Interval::Size{c.prompt_len - c.step0}}; + } + else if (g.output_logits == kAll) { + c.output_logits = Interval{c.step0}; + } + else { + c.output_logits = Interval{c.prompt_len - 1}; + } c.logits_offset = c.output_logits.begin(); } if (g.output_last_hidden_state) { @@ -182,28 +191,49 @@ struct OutputProcessor::Impl { { const int step_size = max_logits_len_; - // Coroutine frame - int p = 0; - auto ranges = data.output_logits; + // Split ranges into PPL callback requests and normal copy requests + decltype(data.output_logits) ppl_ranges; + decltype(data.output_logits) copy_ranges; + for (auto& entry : data.output_logits) { + auto& [i, t, src, dst] = entry; + if (rs[i]->req->logits_cb) { + ppl_ranges.push_back(entry); + } + else { + copy_ranges.push_back(entry); + } + } + + // Coroutine frame for normal copy + int p = 0; using Size = Interval::Size; - bool success = false; + bool copy_success = copy_ranges.empty(); // Erode the range iteratively until empty for (auto r = data.full_logits; r; r = -step_size | r) { - // dbg(&r); if (auto chunk = r & Interval{r.begin(), Size{step_size}}) { - // dbg(&chunk); - // Compute & output full logits by chunks auto logits = lm_head_(h.slice(chunk.begin(), (int)chunk.size())); - success = OutputLogitsImpl(ranges, p, logits, chunk.begin(), 2, rs); - if (success) { // all requests satisfied, exit early - break; + + // Invoke PPL callbacks + if (tp_rank_ == 0) { + const auto stream = core::Context::stream().handle(); + cudaStreamSynchronize(stream); + for (auto& [i, t, src, dst] : ppl_ranges) { + if (t == 2) { + auto& cb = rs[i]->req->logits_cb; + cb(logits.raw_data(), vocab_size_, chunk.begin(), (int)chunk.size(), logits.dtype()); + } + } + } + + if (!copy_success) { + copy_success = OutputLogitsImpl(copy_ranges, p, logits, chunk.begin(), 2, rs); } } } - TM_CHECK(success); // all requests must be satisfied at the end + TM_CHECK(copy_ranges.empty() || copy_success); } template diff --git a/src/turbomind/python/bind.cpp b/src/turbomind/python/bind.cpp index 084cc026dc..b73eae0b50 100644 --- a/src/turbomind/python/bind.cpp +++ b/src/turbomind/python/bind.cpp @@ -1,6 +1,7 @@ // Copyright (c) OpenMMLab. All rights reserved. 
#include +#include <optional> #include #include @@ -348,6 +349,7 @@ PYBIND11_MODULE(_turbomind, m) .def_readwrite("output_logprobs", &ft::GenerationConfig::output_logprobs) .def_readwrite("output_last_hidden_state", &ft::GenerationConfig::output_last_hidden_state) .def_readwrite("output_logits", &ft::GenerationConfig::output_logits) + .def_readwrite("compute_ppl", &ft::GenerationConfig::compute_ppl) .def("__repr__", [](const ft::GenerationConfig& c) { std::ostringstream oss; oss << c; @@ -448,13 +450,14 @@ PYBIND11_MODULE(_turbomind, m) py::class_<ModelRequest>(m, "ModelRequest") .def( "forward", - [](ModelRequest* model_request, - std::shared_ptr input_tensors, - const ft::SessionParam& session, - const ft::GenerationConfig& gen_cfg, - bool stream_output, - bool enable_metrics, - std::function<void()> cb) { + [](ModelRequest* model_request, + std::shared_ptr input_tensors, + const ft::SessionParam& session, + const ft::GenerationConfig& gen_cfg, + bool stream_output, + bool enable_metrics, + std::function<void()> cb, + std::optional<std::function<void(py::object, int, int, int)>> logits_cb) { ModelRequest::InputParam param{}; param.tensors = std::move(input_tensors); param.session = session; @@ -462,6 +465,36 @@ param.stream_output = stream_output; param.enable_metrics = enable_metrics; + if (logits_cb) { + auto py_cb = std::move(*logits_cb); + param.logits_cb = [py_cb = std::move(py_cb)]( + void* data, int vocab_size, int begin, int count, ft::DataType dtype) { + py::gil_scoped_acquire gil; + + int device_id = 0; + cudaGetDevice(&device_id); + Tensor t(data, {count, vocab_size}, dtype, ft::core::Device{ft::kDEVICE, device_id}); + DLManagedTensor* dlmt = TritonTensorToDLManagedTensor(t); + + py::capsule cap(dlmt, kDlTensorCapsuleName, [](PyObject* obj) { + DLManagedTensor* dlmt = + static_cast<DLManagedTensor*>(PyCapsule_GetPointer(obj, kDlTensorCapsuleName)); + if (dlmt) { + dlmt->deleter(dlmt); + } + else { + PyErr_Clear(); + } + }); + + py::object torch = py::module_::import("torch"); + py::object from_dlpack = torch.attr("from_dlpack"); + py::object tensor = from_dlpack(cap); + + py_cb(tensor, vocab_size, begin, count); + }; + } + auto ret = model_request->Forward(std::move(param), [cb = std::move(cb)]() { try { cb(); @@ -478,7 +511,8 @@ "gen_cfg"_a, "stream_output"_a, "enable_metrics"_a, - "cb"_a) + "cb"_a, + "logits_cb"_a = py::none()) .def( "cancel", [](ModelRequest* model_request) {
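[Reviewer note] For reference, a standalone sketch of the arithmetic this change relies on: shifted cross-entropy over packed prefill logits (as in `_compute_ppl_from_logits`) and accumulation of `(loss, count)` pairs across logits chunks (as in the TurboMind `logits_cb` and the PyTorch engine-loop accumulators). Shapes and names below are toy placeholders, not the lmdeploy implementation itself:

```python
# Hedged sketch: chunked (loss, count) accumulation matches a single-pass computation.
import math

import torch
import torch.nn.functional as F

vocab, seq_len = 32, 10
logits = torch.randn(seq_len, vocab)             # packed prefill logits, one sequence
input_ids = torch.randint(0, vocab, (seq_len,))  # corresponding token ids

# Single pass: positions 0..seq_len-2 predict tokens 1..seq_len-1.
full_loss = F.cross_entropy(logits[:-1], input_ids[1:], reduction='sum').item()
full_count = seq_len - 1

# Chunked pass, mirroring the per-chunk callback: each logits chunk is paired with
# the targets shifted by one, clipped to the tokens actually available.
chunk, acc_loss, acc_count = 4, 0.0, 0
for begin in range(0, seq_len, chunk):
    piece = logits[begin:begin + chunk]
    targets = input_ids[begin + 1:begin + piece.size(0) + 1]
    valid = min(piece.size(0), targets.numel())
    if valid > 0:
        acc_loss += F.cross_entropy(piece[:valid], targets[:valid], reduction='sum').item()
        acc_count += valid

assert acc_count == full_count
assert abs(acc_loss - full_loss) < 1e-3
mean_ce = acc_loss / acc_count   # what get_ppl returns per sequence
ppl = math.exp(mean_ce)          # perplexity, if the caller wants it
```

Because only scalar `(loss, count)` pairs cross the engine boundary, the full `[seq_len, vocab_size]` logits tensor never has to be copied to the host, which is the memory saving the updated `get_ppl` docstring refers to.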