diff --git a/lmdeploy/messages.py b/lmdeploy/messages.py index d6cd1a3329..4fba49b089 100644 --- a/lmdeploy/messages.py +++ b/lmdeploy/messages.py @@ -101,7 +101,7 @@ class GenerationConfig: """ n: int = 1 - max_new_tokens: int = 512 + max_new_tokens: int = None do_sample: bool = False top_p: float = 1.0 top_k: int = 50 @@ -112,7 +112,7 @@ class GenerationConfig: random_seed: int = None stop_words: list[str] = None bad_words: list[str] = None - stop_token_ids: list[int] = None + stop_token_ids: list[int] | list[list[int]] = None bad_token_ids: list[int] = None min_new_tokens: int = None skip_special_tokens: bool = True @@ -136,46 +136,91 @@ class GenerationConfig: repetition_ngram_size: int = 0 repetition_ngram_threshold: int = 0 + @staticmethod + def _normalize_stop_token_ids(ids: list[int] | list[list[int]] | None) -> list[list[int]]: + """Normalize stop_token_ids to list[list[int]].""" + if ids is None: + return [] + out: list[list[int]] = [] + for item in ids: + if isinstance(item, int): + out.append([item]) + else: + out.append(list(item)) + return out + def convert_stop_bad_words_to_ids(self, tokenizer: Tokenizer): - """Convert stop_words/bad_sords to ids and append the ids to + """Convert stop_words/bad_words to ids and append the ids to stop_token_ids/bad_token_ids.""" - def special_word_token_ids(words): - if words is not None: - assert isinstance(words, list) and \ - all(isinstance(elem, str) for elem in words), \ - f'stop_words must be a list of str but got {type(words)}' - indexes = [] - for word in words: - indexes += tokenizer.indexes_containing_token(word) - return indexes - return None - - stop_token_ids = special_word_token_ids(self.stop_words) or [] - bad_token_ids = special_word_token_ids(self.bad_words) or [] - stop_token_ids.extend(self.stop_token_ids or []) - bad_token_ids.extend(self.bad_token_ids or []) - self.stop_token_ids = list(set(stop_token_ids)) or None - self.bad_token_ids = list(set(bad_token_ids)) or None + def 
words_to_token_seqs(words: list[str]) -> list[list[int]]: + assert isinstance(words, list) and \ + all(isinstance(elem, str) for elem in words), \ + f'stop_words must be a list of str but got {type(words)}' + seqs: list[list[int]] = [] + for word in words: + single_matches = tokenizer.indexes_containing_token(word) + if single_matches: + for idx in single_matches: + seqs.append([idx]) + else: + encoded = tokenizer.encode(word, add_bos=False) + if encoded: + seqs.append(encoded) + return seqs + + stop_seqs = words_to_token_seqs(self.stop_words) if self.stop_words else [] + bad_seqs = words_to_token_seqs(self.bad_words) if self.bad_words else [] + + stop_seqs.extend(self._normalize_stop_token_ids(self.stop_token_ids)) + bad_seqs.extend([[i] for i in (self.bad_token_ids or [])]) + + # deduplicate stop_token_ids and bad_token_ids + seen = set() + deduped: list[list[int]] = [] + for seq in stop_seqs: + key = tuple(seq) + if key not in seen: + seen.add(key) + deduped.append(seq) + self.stop_token_ids = deduped or None + + seen_bad = set() + deduped_bad: list[int] = [] + for seq in bad_seqs: + if len(seq) > 1: + logger.warning(f'Multi-token bad word {seq} is not supported and ' + 'will be ignored. 
Only single-token bad words can be ' + 'masked in logits processing.') + continue + if seq[0] not in seen_bad: + seen_bad.add(seq[0]) + deduped_bad.append(seq[0]) + self.bad_token_ids = deduped_bad or None def update_from_hf_gen_cfg(self, generation_config, tokenizer_eos_token_id): """Update the stop_token_ids.""" - stop_token_ids = set(self.stop_token_ids or []) + stop_seqs = self._normalize_stop_token_ids(self.stop_token_ids) + existing = {tuple(s) for s in stop_seqs} + + def _add_single(tok_id: int): + key = (tok_id, ) + if key not in existing: + existing.add(key) + stop_seqs.append([tok_id]) - # add tokenizer's eos_token_id if tokenizer_eos_token_id is not None: - stop_token_ids.add(tokenizer_eos_token_id) + _add_single(tokenizer_eos_token_id) - # add eos_token_id from model's generation_config.json file if there - # is any. eos_token_id = generation_config.get('eos_token_id') if eos_token_id is not None: if isinstance(eos_token_id, int): - stop_token_ids.add(eos_token_id) + _add_single(eos_token_id) else: - stop_token_ids.update(eos_token_id) + for eid in eos_token_id: + _add_single(eid) - self.stop_token_ids = list(stop_token_ids) + self.stop_token_ids = stop_seqs def __post_init__(self): """Check input validation.""" @@ -185,6 +230,8 @@ def __post_init__(self): assert self.temperature >= 0 and self.temperature <= 2 # [0,2] assert 0 <= self.min_p <= 1, \ f'min_p should be in range [0, 1], but found {self.min_p}' + if self.stop_token_ids is not None: + self.stop_token_ids = self._normalize_stop_token_ids(self.stop_token_ids) @pydantic_dataclass diff --git a/lmdeploy/pytorch/engine/logits_process.py b/lmdeploy/pytorch/engine/logits_process.py index f6f290fc29..a355b061f9 100644 --- a/lmdeploy/pytorch/engine/logits_process.py +++ b/lmdeploy/pytorch/engine/logits_process.py @@ -210,7 +210,7 @@ def _filter_repetition_ngram_( return scores # use first stop words _, found = ngram(generated_ids, n, threshold, max_n, max_ngram_window_size) - stop_words = stop_words[:, 
0] + stop_words = stop_words[:, 0, 0] # fill all scores -inf scores.masked_fill_(found[:, None], -float('inf')) # set stop words to 0 @@ -245,7 +245,7 @@ class SamplingInputs: bad_words: torch.LongTensor = None bad_mask: torch.BoolTensor = None stop_words: torch.LongTensor = None - stop_mask: torch.BoolTensor = None + stop_word_lens: torch.LongTensor = None repetition_penalty: torch.Tensor = None top_k: torch.LongTensor = None top_p: torch.Tensor = None @@ -428,11 +428,13 @@ async def __call__(self, scores: torch.Tensor) -> torch.Tensor: scores = _process_bad_words_(scores, bad_words, bad_mask) stop_words = sampling_inputs.stop_words - if stop_words is not None: + stop_word_lens = sampling_inputs.stop_word_lens + if stop_words is not None and stop_word_lens is not None: ignore_eos = sampling_inputs.num_ignore_eos > 0 - stop_mask = sampling_inputs.stop_mask - stop_mask = torch.where(ignore_eos[:, None], stop_mask, False) - scores = _process_bad_words_(scores, stop_words, stop_mask) + single_mask = (stop_word_lens == 1) & ignore_eos[:, None] + if single_mask.any(): + single_tokens = stop_words[:, :, 0] + scores = _process_bad_words_(scores, single_tokens, single_mask) return scores, logprobs diff --git a/lmdeploy/pytorch/engine/model_agent/agent.py b/lmdeploy/pytorch/engine/model_agent/agent.py index 6db690ab7f..5d4b4c932a 100644 --- a/lmdeploy/pytorch/engine/model_agent/agent.py +++ b/lmdeploy/pytorch/engine/model_agent/agent.py @@ -701,11 +701,15 @@ async def _step_postprocess_with_output(self, logger.debug(f' rank[{rank}]: synchronize token ids') # stopping criteria + # Use output_token_ids (all tokens accepted this step) so that multi-token + # stop sequences whose last token is not the final spec-decoded token are + # detected correctly. For non-spec AR, output_token_ids == next_token_ids. 
stopped, stop_pos, stopping_criteria = stopping_criteria.step( - next_token_ids, + output_token_ids, sampling_inputs.stop_words, inputs=inputs, extra_inputs=extra_inputs, + stop_word_lens=sampling_inputs.stop_word_lens, ) # send output diff --git a/lmdeploy/pytorch/messages.py b/lmdeploy/pytorch/messages.py index 218cfb5113..bf20195dd1 100644 --- a/lmdeploy/pytorch/messages.py +++ b/lmdeploy/pytorch/messages.py @@ -52,7 +52,7 @@ class SamplingParam: repetition_penalty: float = 1.0 ignore_eos: bool = False random_seed: int = None - stop_words: list[int] = field(default_factory=list) + stop_words: list[list[int]] = field(default_factory=list) bad_words: list[int] = field(default_factory=list) max_new_tokens: int = 512 min_new_tokens: int = 0 @@ -75,7 +75,11 @@ def from_gen_config(cls, gen_config: GenerationConfig): stop_words = gen_config.stop_token_ids or [] bad_words = gen_config.bad_token_ids or [] if gen_config.ignore_eos: - bad_words += stop_words + if any(len(s) > 1 for s in stop_words): + logger.warning('Multi-token stop words are not supported and ' + 'will be ignored. Only single-token stop words can ' + 'be used to stop generation.') + bad_words += [s[0] for s in stop_words if len(s) == 1] stop_words = [] top_k = gen_config.top_k diff --git a/lmdeploy/pytorch/strategies/ar/model_agent.py b/lmdeploy/pytorch/strategies/ar/model_agent.py index 94e4baabdf..72cd0b7b1c 100644 --- a/lmdeploy/pytorch/strategies/ar/model_agent.py +++ b/lmdeploy/pytorch/strategies/ar/model_agent.py @@ -61,41 +61,195 @@ class ARExtraOutputs(ExtraOutputs): @dataclass class ARStoppingCriteria(StoppingCriteria): num_appendable_ids: torch.Tensor + # Tail of previously generated tokens, shape [batch, tail_len]. + # Maintained across steps so that multi-token stop sequences spanning two + # decode steps are detected without relying on the (pipelined) generated_ids + # from SamplingInputs, which lags one step behind. 
+ stop_tail: torch.Tensor | None = None def clone(self): """clone.""" - return ARStoppingCriteria(num_appendable_ids=self.num_appendable_ids) + tail = self.stop_tail.clone() if self.stop_tail is not None else None + return ARStoppingCriteria(num_appendable_ids=self.num_appendable_ids, stop_tail=tail) def merge(self, other: 'ARStoppingCriteria'): """Merge two stopping criteria.""" - new_num_appendable = torch.cat([self.num_appendable_ids, other.num_appendable_ids], dim=0) - return ARStoppingCriteria(num_appendable_ids=new_num_appendable) + new_num = torch.cat([self.num_appendable_ids, other.num_appendable_ids], dim=0) + t0, t1 = self.stop_tail, other.stop_tail + if t0 is None and t1 is None: + new_tail = None + else: + bs0 = self.num_appendable_ids.size(0) + bs1 = other.num_appendable_ids.size(0) + dev = (t0 if t0 is not None else t1).device + if t0 is None: + t0 = torch.zeros(bs0, t1.size(1), dtype=torch.long, device=dev) + if t1 is None: + t1 = torch.zeros(bs1, t0.size(1), dtype=torch.long, device=dev) + # Pad the shorter tail to the same length. 
+ l0, l1 = t0.size(1), t1.size(1) + if l0 < l1: + t0 = torch.nn.functional.pad(t0, (l1 - l0, 0)) + elif l1 < l0: + t1 = torch.nn.functional.pad(t1, (l0 - l1, 0)) + new_tail = torch.cat([t0, t1], dim=0) + return ARStoppingCriteria(num_appendable_ids=new_num, stop_tail=new_tail) def update(self, delta: ModelInputsDelta): """Update stopping criteria.""" indices = delta.indices - new_num_appendable = self.num_appendable_ids[indices] - return ARStoppingCriteria(num_appendable_ids=new_num_appendable) + new_num = self.num_appendable_ids[indices] + new_tail = self.stop_tail[indices] if self.stop_tail is not None else None + return ARStoppingCriteria(num_appendable_ids=new_num, stop_tail=new_tail) @record_function('stopping_criteria') def step(self, token_ids: torch.Tensor, stop_words: torch.Tensor, inputs: ModelInputs | None = None, - extra_inputs: ARExtraInputs | None = None): + extra_inputs: ARExtraInputs | None = None, + stop_word_lens: torch.Tensor | None = None): """Check whether to stop generation.""" num_appendable_ids = self.num_appendable_ids - 1 stopped = num_appendable_ids <= 0 stop_pos = torch.zeros_like(num_appendable_ids) - if stop_words is not None: - sw_stopped = (token_ids[:, None] == stop_words).any(1) + + if stop_words is None or stop_word_lens is None: + new_tail = None + else: + # Set a uniform shape for token_ids for both single and multi-token stop words + token_ids = token_ids.unsqueeze(1) if token_ids.ndim == 1 else token_ids + + sw_stopped, stop_pos, new_tail = self._check_stop_words(token_ids, stop_words, stop_word_lens) + stopped = stopped | sw_stopped one_ids = torch.clamp_max(num_appendable_ids, 0) num_appendable_ids = torch.where(sw_stopped, one_ids, num_appendable_ids) - # I don't know why assign inplace does not works... 
- new_stopping = ARStoppingCriteria(num_appendable_ids=num_appendable_ids) - return stopped, stop_pos, new_stopping + return (stopped, stop_pos, ARStoppingCriteria(num_appendable_ids=num_appendable_ids, stop_tail=new_tail)) + + def _check_stop_words(self, token_ids: torch.Tensor, stop_words: torch.Tensor, stop_word_lens: torch.Tensor): + """Vectorized multi-token stop word detection. + + Args: + token_ids: [batch, step_len], -1 for invalid positions. + Modified **in-place** (tokens after stop are set to -1). + stop_words: [batch, num_seqs, max_slen] + stop_word_lens: [batch, num_seqs] + + Returns: + sw_stopped: [batch] bool + stop_pos: [batch] long – step-relative index of the stop token + new_tail: [batch, tail_len] or None + """ + max_slen = int(stop_word_lens.max().item()) + + if max_slen <= 1: + # Fast path when every stop word is a single token + return self._check_single_stop_words(token_ids, stop_words, stop_word_lens) + + # General path for multi-token stop words + return self._check_multi_stop_words(token_ids, stop_words, stop_word_lens, max_slen) + + def _check_single_stop_words(self, token_ids: torch.Tensor, stop_words: torch.Tensor, stop_word_lens: torch.Tensor): + """Fast path: every stop word is a single token, AR always has L==1.""" + batch_size = token_ids.size(0) + device = token_ids.device + targets = stop_words[:, :, 0] # [B, S] + valid = (stop_word_lens == 1) # [B, S] + # token_ids [B, 1] broadcasts against targets [B, S] + sw_stopped = ((token_ids == targets) & valid).any(1) # [B] + stop_pos = torch.zeros(batch_size, dtype=torch.long, device=device) + return sw_stopped, stop_pos, None + + def _check_multi_stop_words(self, token_ids: torch.Tensor, stop_words: torch.Tensor, stop_word_lens: torch.Tensor, + max_slen: int): + """General path for multi-token stop words. + + Per-length unfold loop (each length needs its own window count), but + iterates ``range(1, max_slen+1)`` instead of calling the GPU-syncing + ``stop_word_lens.unique().tolist()``. 
+ """ + tail_len = max_slen - 1 + batch_size = token_ids.size(0) + step_len = token_ids.size(1) + device = token_ids.device + + # -- 1. build history = [prev_tail | token_ids] -- + prev_tail = self._get_prev_tail(batch_size, tail_len, device) + if prev_tail is not None: + history = torch.cat([prev_tail, token_ids], dim=1) + else: + history = token_ids + hist_len = history.size(1) + + # -- 2. sliding-window matching per length -- + NO_MATCH = hist_len + best_end = history.new_full((batch_size, ), NO_MATCH) + for slen in range(1, max_slen + 1): + if hist_len < slen: + continue + windows = history.unfold(1, slen, 1) # [B, W, slen] + targets = stop_words[:, :, :slen] # [B, S, slen] + len_mask = (stop_word_lens == slen) # [B, S] + + match = (windows.unsqueeze(2) == targets.unsqueeze(1)).all(-1) + match = match & len_mask.unsqueeze(1) + match_any = match.any(2) # [B, W] + + min_win = max(0, tail_len - slen + 1) + if min_win > 0: + match_any[:, :min_win] = False + + has_match = match_any.any(1) + first_win = match_any.int().argmax(1) + end_pos = first_win + slen - 1 + better = has_match & (end_pos < best_end) + best_end = torch.where(better, end_pos, best_end) + + sw_stopped = best_end < NO_MATCH + + # -- 3. compute stop_pos and mask trailing tokens -- + step_stop_pos = best_end - tail_len + stop_pos = torch.where(sw_stopped, step_stop_pos, sw_stopped.new_zeros(batch_size, dtype=torch.long)) + + col_idx = torch.arange(step_len, device=device) + after_stop = (col_idx > step_stop_pos.unsqueeze(1)) & sw_stopped.unsqueeze(1) + token_ids[after_stop] = -1 + + # -- 4. 
update tail -- + new_tail = self._build_new_tail(history, tail_len, sw_stopped, best_end, token_ids) + + return sw_stopped, stop_pos, new_tail + + def _get_prev_tail(self, batch_size: int, tail_len: int, device: torch.device) -> torch.Tensor | None: + """Return the previous tail padded/trimmed to ``tail_len``.""" + if tail_len <= 0: + return None + if self.stop_tail is None: + return torch.zeros(batch_size, tail_len, dtype=torch.long, device=device) + prev = self.stop_tail.to(device) + pt_len = prev.size(1) + if pt_len < tail_len: + prev = torch.nn.functional.pad(prev, (tail_len - pt_len, 0), value=-1) + elif pt_len > tail_len: + prev = prev[:, -tail_len:] + return prev + + @staticmethod + def _build_new_tail(history: torch.Tensor, tail_len: int, sw_stopped: torch.Tensor, best_end: torch.Tensor, + token_ids: torch.Tensor) -> torch.Tensor | None: + """Gather the last ``tail_len`` valid tokens from *history*.""" + if tail_len <= 0: + return None + valid_counts = (token_ids >= 0).sum(1) + effective_end = torch.where(sw_stopped, best_end, tail_len + valid_counts - 1) + effective_end = effective_end.clamp(min=tail_len - 1) + + offsets = torch.arange(tail_len, device=history.device) + indices = (effective_end - tail_len + 1).unsqueeze(1) + offsets.unsqueeze(0) + indices = indices.clamp(min=0, max=history.size(1) - 1) + return history.gather(1, indices) class ARModelAgentStrategy(ModelAgentStrategy): diff --git a/lmdeploy/pytorch/strategies/ar/sampling.py b/lmdeploy/pytorch/strategies/ar/sampling.py index 8a8c45d032..95a566d680 100644 --- a/lmdeploy/pytorch/strategies/ar/sampling.py +++ b/lmdeploy/pytorch/strategies/ar/sampling.py @@ -97,7 +97,8 @@ def __gather_params(): bw = param.bad_words sw = param.stop_words if (not param.ignore_eos and seq.num_new_tokens < param.min_new_tokens): - bw = bw + sw + # During min_new_tokens period suppress single-token stops as bad words. 
+ bw = bw + [s[0] for s in sw if len(s) == 1] bad_words[idx] = bw stop_words[idx] = sw logits_processors[idx] = param.logits_processors @@ -143,6 +144,24 @@ def __get_bad_words(bad_words): mask = ret >= 0 return ret, mask + def __get_stop_words(stop_words_list): + """Build stop_words [batch, num_seqs, max_len] and stop_word_lens + [batch, num_seqs].""" + max_num_seqs = max(len(sw) for sw in stop_words_list) + if max_num_seqs == 0: + return None, None + max_len = max((len(s) for sw in stop_words_list for s in sw), default=0) + if max_len == 0: + return None, None + seqs = torch.zeros((batch_size, max_num_seqs, max_len), dtype=torch.long) + lens = torch.zeros((batch_size, max_num_seqs), dtype=torch.long) + for i, sw in enumerate(stop_words_list): + for j, seq in enumerate(sw): + slen = len(seq) + seqs[i, j, :slen] = torch.tensor(seq, dtype=torch.long) + lens[i, j] = slen + return seqs, lens + __gather_params() if all(rp == 1.0 for rp in repetition_penalty): @@ -156,7 +175,7 @@ def __get_bad_words(bad_words): temperature = None bad_words, bad_mask = __get_bad_words(bad_words) - stop_words, stop_mask = __get_bad_words(stop_words) + stop_words, stop_word_lens = __get_stop_words(stop_words) max_top_k = max(top_k) if min(top_k) <= 0: @@ -201,7 +220,7 @@ def __get_bad_words(bad_words): bad_words=bad_words, bad_mask=bad_mask, stop_words=stop_words, - stop_mask=stop_mask, + stop_word_lens=stop_word_lens, repetition_penalty=repetition_penalty, top_k=top_k, top_p=top_p, diff --git a/lmdeploy/pytorch/strategies/dllm/sampling.py b/lmdeploy/pytorch/strategies/dllm/sampling.py index ab5174f017..67b6358bc3 100644 --- a/lmdeploy/pytorch/strategies/dllm/sampling.py +++ b/lmdeploy/pytorch/strategies/dllm/sampling.py @@ -33,7 +33,7 @@ def make_sampling_inputs(self, seqs: SeqList) -> SamplingInputs: 'bad_words', 'bad_mask', 'stop_words', - 'stop_mask', + 'stop_word_lens', 'repetition_penalty', 'top_k', 'top_p', diff --git a/lmdeploy/serve/core/async_engine.py 
b/lmdeploy/serve/core/async_engine.py index 114775fc69..dedb5068b4 100644 --- a/lmdeploy/serve/core/async_engine.py +++ b/lmdeploy/serve/core/async_engine.py @@ -39,6 +39,53 @@ logger = get_logger('lmdeploy') +def _commit_stream_tokens(all_ids: list[int], input_len: int, pending_ids: list[int], new_ids: list[int], + multi_stop_seqs: list[list[int]], holdback_len: int): + """Commit safe streamed tokens with multi-stop holdback. + + Args: + all_ids: full token_ids buffer (prompt + committed generated tokens). + input_len: length of the prompt prefix in all_ids; generated tokens start here. + pending_ids: tokens buffered but not yet committed (mutated in place). + new_ids: freshly arrived tokens to add this iteration. + multi_stop_seqs: list of multi-token stop sequences to match against. + holdback_len: number of tokens to withhold from the tail (max_stop_len - 1). + + Returns: + hit_stop_token: matched stop length (0 if none). + matched_stop_ids: matched stop token ids. + commit_ids: token ids safe to emit this round. + pending_ids: updated pending buffer. 
+ """ + pending_ids.extend(new_ids) + hit_stop_token = 0 + matched_stop_ids: list[int] = [] + for mseq in multi_stop_seqs: + slen = len(mseq) + plen = len(pending_ids) + total = len(all_ids) - input_len + plen + if total < slen: + continue + if plen >= slen: + tail = pending_ids[-slen:] + else: + need = slen - plen + tail = all_ids[len(all_ids) - need:] + pending_ids + if tail == mseq: + hit_stop_token = slen + matched_stop_ids = list(mseq) + del pending_ids[-slen:] + break + if hit_stop_token: + commit_ids = list(pending_ids) + pending_ids.clear() + else: + commit_len = max(0, len(pending_ids) - holdback_len) + commit_ids = pending_ids[:commit_len] + del pending_ids[:commit_len] + return hit_stop_token, matched_stop_ids, commit_ids, pending_ids + + @dataclasses.dataclass class GenOut: """Pack all response information together.""" @@ -144,8 +191,6 @@ def __init__(self, # parameters for member functions self.stop_words = _stop_words(self.chat_template.stop_words, self.tokenizer) - if self.stop_words is not None: - self.stop_words = self.stop_words[0][0].tolist() self.backend = backend self.request_logger = RequestLogger(max_log_len) @@ -356,6 +401,8 @@ async def generate( else: logger.warning('chat_template_kwargs["enable_thinking"] is already set, ' 'the value will not be overwritten by enable_thinking') + + gen_config = self._determine_gen_config(session, input_ids, gen_config=gen_config) if messages: prompt = messages self.request_logger.log_prompt(session, prompt=prompt) @@ -381,8 +428,6 @@ async def generate( # Figure out a graceful way to handle the invalid input prompt_input = dict(input_ids=input_ids) - gen_config = self._determine_gen_config(session, input_ids, gen_config=gen_config) - if gen_config.max_new_tokens == 0: logger.info(f'run out of tokens. 
session={session_id}.') yield GenOut(response='', @@ -416,9 +461,14 @@ async def generate( def is_error(status): return status not in [ResponseType.SUCCESS, ResponseType.FINISH, ResponseType.CANCEL] - stop_ids = [] + single_stop_ids: set = set() + multi_stop_seqs: list = [] if not gen_config.ignore_eos: - stop_ids = gen_config.stop_token_ids or [] + for seq in (gen_config.stop_token_ids or []): + if len(seq) == 1: + single_stop_ids.add(seq[0]) + else: + multi_stop_seqs.append(seq) metrics_processor.increase_total_requests() async with session.request_handle() as handle: @@ -451,6 +501,11 @@ def is_error(status): step=history_len) as gen: logger.debug(f'[generate] session {session_id} started') hit_stop_token = 0 + stop_by_single = False + matched_stop_ids: list[int] = [] + max_multi_stop_len = max((len(s) for s in multi_stop_seqs), default=0) + holdback_len = max(0, max_multi_stop_len - 1) + pending_ids: list[int] = [] req_stats = RequestStats(prompt_tokens=input_len) # per-request stats # We use this as default outputs in case the async_stream_infer of the Engine yields empty generator. 
@@ -468,12 +523,32 @@ def is_error(status): output_len = len(outputs.token_ids) if hit_stop_token or output_len == 0: continue - - # This assumes the engine will stop when stop token is hit - if output_len and outputs.token_ids[-1] in stop_ids: + # Classify this step's tokens against single/multi-token stop criteria. + # Check single-token stop + if output_len and outputs.token_ids[-1] in single_stop_ids: hit_stop_token = 1 - - token_ids += outputs.token_ids[:output_len - hit_stop_token] + stop_by_single = True + matched_stop_ids = [outputs.token_ids[-1]] + + new_ids = outputs.token_ids[:output_len - hit_stop_token] + if not hit_stop_token: + if multi_stop_seqs: + hit_stop_token, matched_multi_ids, commit_ids, pending_ids = _commit_stream_tokens( + token_ids, + input_len, + pending_ids, + new_ids, + multi_stop_seqs, + holdback_len, + ) + if matched_multi_ids: + matched_stop_ids = matched_multi_ids + else: + commit_ids = new_ids + else: + commit_ids = pending_ids + new_ids + pending_ids = [] + token_ids.extend(commit_ids) gen_len = len(token_ids) - input_len ids_offset = state.ids_offset @@ -507,7 +582,18 @@ def is_error(status): if outputs.status == ResponseType.CANCEL: finish_reason = 'abort' else: - finish_reason = 'stop' if outputs.token_ids[-1] in stop_ids else 'length' + if not hit_stop_token and pending_ids: + token_ids.extend(pending_ids) + pending_ids = [] + gen_len = len(token_ids) - input_len + ids_offset = state.ids_offset + response, state = self.tokenizer.detokenize_incrementally( + token_ids, + state, + skip_special_tokens=gen_config.skip_special_tokens, + spaces_between_special_tokens=gen_config.spaces_between_special_tokens) + is_stop = stop_by_single or hit_stop_token > 0 + finish_reason = 'stop' if is_stop else 'length' # utf-8 char at the end means it's a potential unfinished byte sequence if not response.endswith('�'): @@ -515,13 +601,14 @@ def is_error(status): response = '' token_ids, logits, last_hidden_state, logprobs = [], None, None, None if 
gen_config.include_stop_str_in_output and finish_reason == 'stop': - # return the eos token id (MUST be in a list), eos string, eos token's logits and so on - token_ids = outputs.token_ids[-1:] + token_ids = matched_stop_ids if matched_stop_ids else outputs.token_ids[-1:] + stop_len = len(token_ids) response = self.tokenizer.decode(token_ids, skip_special_tokens=False) - logits = outputs.logits[-1:] if outputs.logits is not None else None - last_hidden_state = outputs.last_hidden_state[-1:] if outputs.last_hidden_state else None - logprobs = outputs.logprobs[-1:] if outputs.logprobs else None - gen_len += 1 + logits = outputs.logits[-stop_len:] if outputs.logits is not None else None + last_hidden_state = (outputs.last_hidden_state[-stop_len:] + if outputs.last_hidden_state else None) + logprobs = outputs.logprobs[-stop_len:] if outputs.logprobs else None + gen_len += stop_len # router replay routed_experts = outputs.routed_experts diff --git a/lmdeploy/tokenizer.py b/lmdeploy/tokenizer.py index c184e53111..e33af33b7d 100644 --- a/lmdeploy/tokenizer.py +++ b/lmdeploy/tokenizer.py @@ -182,9 +182,9 @@ def indexes_containing_token(self, token: str): # there might be token id that exceeds self.vocab_size if len(indexes) == 0: indexes = self.encode(token, False) - if len(indexes) != 1: - self.logger.warning(f'The token {token}, its length of indexes {indexes} is ' - 'not 1. Currently, it can not be used as stop words') + if len(indexes) > 1: + # Multi-token encoding: return empty so callers can handle + # the multi-token case via encode() directly. indexes = [] self._indexes_tokens_deque.append((token, indexes)) return indexes @@ -544,7 +544,7 @@ def indexes_containing_token(self, token): the input token.""" encoded = self.encode(token, add_bos=False) if len(encoded) > 1: - self.logger.warning(f'The token {token}, its length of indexes {encoded} is over ' - 'than 1. 
Currently, it can not be used as stop words') + # Multi-token encoding: return empty so callers can handle + # the multi-token case via encode() directly. return [] return self.model.indexes_containing_token(token) diff --git a/lmdeploy/turbomind/turbomind.py b/lmdeploy/turbomind/turbomind.py index f95b2b93ca..e6e5360991 100644 --- a/lmdeploy/turbomind/turbomind.py +++ b/lmdeploy/turbomind/turbomind.py @@ -41,12 +41,20 @@ MAX_LOGPROBS = 1024 -def _construct_stop_or_bad_words(words: list[int] = None): - if words is None or len(words) == 0: +def _construct_stop_or_bad_words(seqs: list[list[int]] = None): + """Build packed (token_ids, offsets) for TurboMind stop/bad words. + + Each inner list is a token ID sequence. The offset array stores cumulative lengths so the C++ kernel knows where + each sequence ends. + """ + if not seqs: return None - offsets = list(range(1, len(words) + 1)) - combined = [words, offsets] - return combined + tokens: list[int] = [] + offsets: list[int] = [] + for seq in seqs: + tokens.extend(seq) + offsets.append(len(tokens)) + return [tokens, offsets] def _np_dict_to_tm_dict(np_dict: dict): @@ -807,9 +815,9 @@ def _get_generation_config(self, cfg: GenerationConfig): c.min_p = cfg.min_p c.temperature = cfg.temperature if cfg.stop_token_ids: - c.eos_ids = cfg.stop_token_ids + c.eos_ids = [s[0] for s in cfg.stop_token_ids if len(s) == 1] if cfg.bad_token_ids: - c.bad_ids = _construct_stop_or_bad_words(cfg.bad_token_ids) + c.bad_ids = _construct_stop_or_bad_words([[tid] for tid in cfg.bad_token_ids]) if not cfg.ignore_eos and cfg.stop_token_ids: c.stop_ids = _construct_stop_or_bad_words(cfg.stop_token_ids) c.repetition_penalty = cfg.repetition_penalty diff --git a/lmdeploy/utils.py b/lmdeploy/utils.py index 5e06ab5ae9..387f522760 100644 --- a/lmdeploy/utils.py +++ b/lmdeploy/utils.py @@ -195,28 +195,31 @@ def filter_suffix(response: str, suffixes: list[str] | None = None) -> str: return response -# TODO remove stop_word_offsets stuff and make it 
clean -def _stop_words(stop_words: list[int | str], tokenizer: object): - """Return list of stop-words to numpy.ndarray.""" - import numpy as np +def _stop_words(stop_words: list[int | str], tokenizer: object) -> list[list[int]] | None: + """Convert chat-template stop words to List[List[int]]. + + Each element is a token ID sequence representing one stop word. Single-token matches from vocab scan produce + length-1 lists. Multi-token words that require encoding produce longer lists. + """ if stop_words is None: return None assert isinstance(stop_words, list) and \ all(isinstance(elem, (str, int)) for elem in stop_words), \ f'stop_words must be a list but got {type(stop_words)}' - stop_indexes = [] + seqs: list[list[int]] = [] for stop_word in stop_words: - if isinstance(stop_word, str): - stop_indexes += tokenizer.indexes_containing_token(stop_word) - elif isinstance(stop_word, int): - stop_indexes.append(stop_word) - assert isinstance(stop_indexes, list) and all(isinstance(elem, int) for elem in stop_indexes), 'invalid stop_words' - # each id in stop_indexes represents a stop word - # refer to https://github.com/fauxpilot/fauxpilot/discussions/165 for - # detailed explanation about fastertransformer's stop_indexes - stop_word_offsets = range(1, len(stop_indexes) + 1) - stop_words = np.array([[stop_indexes, stop_word_offsets]]).astype(np.int32) - return stop_words + if isinstance(stop_word, int): + seqs.append([stop_word]) + elif isinstance(stop_word, str): + single_matches = tokenizer.indexes_containing_token(stop_word) + if single_matches: + for idx in single_matches: + seqs.append([idx]) + else: + encoded = tokenizer.encode(stop_word, add_bos=False) + if encoded: + seqs.append(encoded) + return seqs or None def get_hf_gen_cfg(path: str):