-
Notifications
You must be signed in to change notification settings - Fork 682
Support multi stop words #4454
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Support multi stop words #4454
Changes from 6 commits
f959a2f
9e58185
51dd33c
8874b4f
9985667
27bdb55
b7b086b
36aec3b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -2,7 +2,7 @@ | |||||||||||||||
| import enum | ||||||||||||||||
| import time | ||||||||||||||||
| from dataclasses import dataclass, field | ||||||||||||||||
| from typing import Any, Callable, Dict, List, Literal, Optional | ||||||||||||||||
| from typing import Any, Callable, Literal | ||||||||||||||||
|
|
||||||||||||||||
| import torch | ||||||||||||||||
| from pydantic.dataclasses import dataclass as pydantic_dataclass | ||||||||||||||||
|
|
@@ -100,7 +100,7 @@ class GenerationConfig: | |||||||||||||||
| """ | ||||||||||||||||
|
|
||||||||||||||||
| n: int = 1 | ||||||||||||||||
| max_new_tokens: int = 512 | ||||||||||||||||
| max_new_tokens: int = None | ||||||||||||||||
| do_sample: bool = False | ||||||||||||||||
| top_p: float = 1.0 | ||||||||||||||||
| top_k: int = 50 | ||||||||||||||||
|
|
@@ -109,24 +109,24 @@ class GenerationConfig: | |||||||||||||||
| repetition_penalty: float = 1.0 | ||||||||||||||||
| ignore_eos: bool = False | ||||||||||||||||
| random_seed: int = None | ||||||||||||||||
| stop_words: List[str] = None | ||||||||||||||||
| bad_words: List[str] = None | ||||||||||||||||
| stop_token_ids: List[int] = None | ||||||||||||||||
| bad_token_ids: List[int] = None | ||||||||||||||||
| stop_words: list[str] = None | ||||||||||||||||
| bad_words: list[str] = None | ||||||||||||||||
| stop_token_ids: list[int] | list[list[int]] = None | ||||||||||||||||
| bad_token_ids: list[int] = None | ||||||||||||||||
| min_new_tokens: int = None | ||||||||||||||||
| skip_special_tokens: bool = True | ||||||||||||||||
| spaces_between_special_tokens: bool = True | ||||||||||||||||
| logprobs: int = None | ||||||||||||||||
| response_format: Optional[Dict] = None | ||||||||||||||||
| logits_processors: Optional[List[LogitsProcessor]] = None | ||||||||||||||||
| response_format: dict | None = None | ||||||||||||||||
| logits_processors: list[LogitsProcessor] | None = None | ||||||||||||||||
| output_logits: Literal['all', 'generation'] = None | ||||||||||||||||
| output_last_hidden_state: Literal['all', 'generation'] = None | ||||||||||||||||
| include_stop_str_in_output: bool = False | ||||||||||||||||
|
|
||||||||||||||||
| # for disaggregation | ||||||||||||||||
| with_cache: bool = False | ||||||||||||||||
| preserve_cache: bool = False | ||||||||||||||||
| migration_request: Optional[MigrationRequest] = None | ||||||||||||||||
| migration_request: MigrationRequest | None = None | ||||||||||||||||
|
|
||||||||||||||||
| # router replay | ||||||||||||||||
| return_routed_experts: bool = False | ||||||||||||||||
|
|
@@ -135,46 +135,91 @@ class GenerationConfig: | |||||||||||||||
| repetition_ngram_size: int = 0 | ||||||||||||||||
| repetition_ngram_threshold: int = 0 | ||||||||||||||||
|
|
||||||||||||||||
| @staticmethod | ||||||||||||||||
| def _normalize_stop_token_ids(ids: list[int] | list[list[int]] | None) -> list[list[int]]: | ||||||||||||||||
| """Normalize stop_token_ids to list[list[int]].""" | ||||||||||||||||
| if ids is None: | ||||||||||||||||
| return [] | ||||||||||||||||
| out: list[list[int]] = [] | ||||||||||||||||
| for item in ids: | ||||||||||||||||
| if isinstance(item, int): | ||||||||||||||||
| out.append([item]) | ||||||||||||||||
| else: | ||||||||||||||||
| out.append(list(item)) | ||||||||||||||||
| return out | ||||||||||||||||
|
|
||||||||||||||||
| def convert_stop_bad_words_to_ids(self, tokenizer: Tokenizer): | ||||||||||||||||
| """Convert stop_words/bad_sords to ids and append the ids to | ||||||||||||||||
| """Convert stop_words/bad_words to ids and append the ids to | ||||||||||||||||
| stop_token_ids/bad_token_ids.""" | ||||||||||||||||
|
|
||||||||||||||||
| def special_word_token_ids(words): | ||||||||||||||||
| if words is not None: | ||||||||||||||||
| assert isinstance(words, List) and \ | ||||||||||||||||
| all(isinstance(elem, str) for elem in words), \ | ||||||||||||||||
| f'stop_words must be a list of str but got {type(words)}' | ||||||||||||||||
| indexes = [] | ||||||||||||||||
| for word in words: | ||||||||||||||||
| indexes += tokenizer.indexes_containing_token(word) | ||||||||||||||||
| return indexes | ||||||||||||||||
| return None | ||||||||||||||||
|
|
||||||||||||||||
| stop_token_ids = special_word_token_ids(self.stop_words) or [] | ||||||||||||||||
| bad_token_ids = special_word_token_ids(self.bad_words) or [] | ||||||||||||||||
| stop_token_ids.extend(self.stop_token_ids or []) | ||||||||||||||||
| bad_token_ids.extend(self.bad_token_ids or []) | ||||||||||||||||
| self.stop_token_ids = list(set(stop_token_ids)) or None | ||||||||||||||||
| self.bad_token_ids = list(set(bad_token_ids)) or None | ||||||||||||||||
| def words_to_token_seqs(words: list[str]) -> list[list[int]]: | ||||||||||||||||
| assert isinstance(words, list) and \ | ||||||||||||||||
| all(isinstance(elem, str) for elem in words), \ | ||||||||||||||||
| f'stop_words must be a list of str but got {type(words)}' | ||||||||||||||||
| seqs: list[list[int]] = [] | ||||||||||||||||
| for word in words: | ||||||||||||||||
| single_matches = tokenizer.indexes_containing_token(word) | ||||||||||||||||
| if single_matches: | ||||||||||||||||
| for idx in single_matches: | ||||||||||||||||
| seqs.append([idx]) | ||||||||||||||||
| else: | ||||||||||||||||
| encoded = tokenizer.encode(word, add_bos=False) | ||||||||||||||||
| if encoded: | ||||||||||||||||
| seqs.append(encoded) | ||||||||||||||||
| return seqs | ||||||||||||||||
|
|
||||||||||||||||
| stop_seqs = words_to_token_seqs(self.stop_words) if self.stop_words else [] | ||||||||||||||||
| bad_seqs = words_to_token_seqs(self.bad_words) if self.bad_words else [] | ||||||||||||||||
|
|
||||||||||||||||
| stop_seqs.extend(self._normalize_stop_token_ids(self.stop_token_ids)) | ||||||||||||||||
| bad_seqs.extend([[i] for i in (self.bad_token_ids or [])]) | ||||||||||||||||
|
|
||||||||||||||||
| # deduplicate stop_token_ids and bad_token_ids | ||||||||||||||||
| seen = set() | ||||||||||||||||
| deduped: list[list[int]] = [] | ||||||||||||||||
| for seq in stop_seqs: | ||||||||||||||||
| key = tuple(seq) | ||||||||||||||||
| if key not in seen: | ||||||||||||||||
| seen.add(key) | ||||||||||||||||
| deduped.append(seq) | ||||||||||||||||
| self.stop_token_ids = deduped or None | ||||||||||||||||
|
||||||||||||||||
| self.stop_token_ids = deduped or None | |
| # Preserve backward-compatible shape: use flat list[int] when only | |
| # single-token stop sequences are present; otherwise keep nested list[list[int]]. | |
| if deduped and all(len(seq) == 1 for seq in deduped): | |
| self.stop_token_ids = [seq[0] for seq in deduped] or None | |
| else: | |
| self.stop_token_ids = deduped or None |
Copilot
AI
Mar 24, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This PR changes GenerationConfig.stop_token_ids to be normalized/stored as list[list[int]] (to support multi-token stop sequences). There are existing unit tests (e.g. tests/test_lmdeploy/test_messages.py::test_engine_generation_config) that assert stop_token_ids is List[int], so the current test suite will fail and there are no new tests covering multi-token stop sequences. Please update the existing tests and add coverage for: (1) mixed stop_token_ids input ([int] and [[...]]), (2) multi-token stop_words encoding path, and (3) streaming stop behavior (holdback) if applicable.
Copilot
AI
Mar 24, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
TurbomindEngineConfig defines devices twice (once at line 315 and again at line 331). In a dataclass, the later field overwrites the earlier one, which is error-prone and can lead to confusing defaults/type hints. Remove the duplicate and keep a single devices definition with the intended type/default.
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -695,11 +695,15 @@ async def _step_postprocess_with_output(self, | |
| logger.debug(f'<ForwardTask> rank[{rank}]: synchronize token ids') | ||
|
|
||
| # stopping criteria | ||
| # Use output_token_ids (all tokens accepted this step) so that multi-token | ||
| # stop sequences whose last token is not the final spec-decoded token are | ||
| # detected correctly. For non-spec AR, output_token_ids == next_token_ids. | ||
| stopped, stop_pos, stopping_criteria = stopping_criteria.step( | ||
| next_token_ids, | ||
| output_token_ids, | ||
| sampling_inputs.stop_words, | ||
| inputs=inputs, | ||
| extra_inputs=extra_inputs, | ||
| stop_word_lens=sampling_inputs.stop_word_lens, | ||
| ) | ||
|
Comment on lines
703
to
713
|
||
|
|
||
| # send output | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
`GenerationConfig.max_new_tokens` now defaults to `None`, but other code paths construct `GenerationConfig()` and immediately treat `max_new_tokens` as an int. For example, `lmdeploy/pytorch/engine/engine_instance.py` creates `gen_config = gen_config or GenerationConfig()` and `SamplingParam.from_gen_config()` then does numeric comparisons (`max_new_tokens < 0`, etc.), which will raise a `TypeError` when `max_new_tokens` is `None`. Either keep the default as an int (e.g., 512) or update all consumers (not just `AsyncEngine._determine_gen_config`) to handle `None` safely.