Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 75 additions & 28 deletions lmdeploy/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ class GenerationConfig:
"""

n: int = 1
max_new_tokens: int = 512
max_new_tokens: int = None
Copy link

Copilot AI Mar 24, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

GenerationConfig.max_new_tokens now defaults to None, but other code paths construct GenerationConfig() and immediately treat max_new_tokens as an int. For example, lmdeploy/pytorch/engine/engine_instance.py creates gen_config = gen_config or GenerationConfig() and SamplingParam.from_gen_config() then does numeric comparisons (max_new_tokens < 0, etc.), which will raise a TypeError when max_new_tokens is None. Either keep the default as an int (e.g., 512) or update all consumers (not just AsyncEngine._determine_gen_config) to handle None safely.

Suggested change
max_new_tokens: int = None
max_new_tokens: int | None = 512

Copilot uses AI. Check for mistakes.
do_sample: bool = False
top_p: float = 1.0
top_k: int = 50
Expand All @@ -112,7 +112,7 @@ class GenerationConfig:
random_seed: int = None
stop_words: list[str] = None
bad_words: list[str] = None
stop_token_ids: list[int] = None
stop_token_ids: list[int] | list[list[int]] = None
bad_token_ids: list[int] = None
min_new_tokens: int = None
skip_special_tokens: bool = True
Expand All @@ -136,46 +136,91 @@ class GenerationConfig:
repetition_ngram_size: int = 0
repetition_ngram_threshold: int = 0

@staticmethod
def _normalize_stop_token_ids(ids: list[int] | list[list[int]] | None) -> list[list[int]]:
"""Normalize stop_token_ids to list[list[int]]."""
if ids is None:
return []
out: list[list[int]] = []
for item in ids:
if isinstance(item, int):
out.append([item])
else:
out.append(list(item))
return out

def convert_stop_bad_words_to_ids(self, tokenizer: 'Tokenizer'):
    """Convert stop_words/bad_words to ids and append the ids to
    stop_token_ids/bad_token_ids.

    After this call ``self.stop_token_ids`` is an order-preserving,
    deduplicated ``list[list[int]]`` (or ``None``), supporting
    multi-token stop sequences; ``self.bad_token_ids`` is a
    deduplicated ``list[int]`` (or ``None``) — multi-token bad words
    are dropped with a warning because logits masking only supports
    single tokens.
    """

    def words_to_token_seqs(words: list[str]) -> list[list[int]]:
        # Map each word to one or more candidate token sequences.
        assert isinstance(words, list) and \
            all(isinstance(elem, str) for elem in words), \
            f'stop_words must be a list of str but got {type(words)}'
        seqs: list[list[int]] = []
        for word in words:
            single_matches = tokenizer.indexes_containing_token(word)
            if single_matches:
                # The word maps to single token(s): one sequence per match.
                for idx in single_matches:
                    seqs.append([idx])
            else:
                # Fall back to a multi-token encoding of the word.
                encoded = tokenizer.encode(word, add_bos=False)
                if encoded:
                    seqs.append(encoded)
        return seqs

    stop_seqs = words_to_token_seqs(self.stop_words) if self.stop_words else []
    bad_seqs = words_to_token_seqs(self.bad_words) if self.bad_words else []

    stop_seqs.extend(self._normalize_stop_token_ids(self.stop_token_ids))
    bad_seqs.extend([[i] for i in (self.bad_token_ids or [])])

    # deduplicate stop sequences, keyed on tuple, keeping first occurrence
    seen = set()
    deduped: list[list[int]] = []
    for seq in stop_seqs:
        key = tuple(seq)
        if key not in seen:
            seen.add(key)
            deduped.append(seq)
    self.stop_token_ids = deduped or None

    # deduplicate bad token ids; only single tokens can be masked in logits
    seen_bad = set()
    deduped_bad: list[int] = []
    for seq in bad_seqs:
        if len(seq) > 1:
            logger.warning(f'Multi-token bad word {seq} is not supported and '
                           'will be ignored. Only single-token bad words can be '
                           'masked in logits processing.')
            continue
        if seq[0] not in seen_bad:
            seen_bad.add(seq[0])
            deduped_bad.append(seq[0])
    self.bad_token_ids = deduped_bad or None

def update_from_hf_gen_cfg(self, generation_config, tokenizer_eos_token_id):
    """Merge EOS token ids from the tokenizer and the model's
    generation_config.json into ``stop_token_ids``.

    Existing stop sequences are kept in order; each EOS id is appended
    as a single-token sequence unless an identical sequence is already
    present.
    """
    seqs = self._normalize_stop_token_ids(self.stop_token_ids)
    known = {tuple(seq) for seq in seqs}

    # Gather EOS candidates: the tokenizer's eos first, then any eos
    # id(s) declared in the HF generation config (scalar or list).
    candidates = []
    if tokenizer_eos_token_id is not None:
        candidates.append(tokenizer_eos_token_id)
    hf_eos = generation_config.get('eos_token_id')
    if hf_eos is not None:
        candidates.extend([hf_eos] if isinstance(hf_eos, int) else list(hf_eos))

    for tok in candidates:
        if (tok, ) not in known:
            known.add((tok, ))
            seqs.append([tok])

    self.stop_token_ids = seqs

def __post_init__(self):
"""Check input validation."""
Expand All @@ -185,6 +230,8 @@ def __post_init__(self):
assert self.temperature >= 0 and self.temperature <= 2 # [0,2]
assert 0 <= self.min_p <= 1, \
f'min_p should be in range [0, 1], but found {self.min_p}'
if self.stop_token_ids is not None:
self.stop_token_ids = self._normalize_stop_token_ids(self.stop_token_ids)


@pydantic_dataclass
Expand Down
14 changes: 8 additions & 6 deletions lmdeploy/pytorch/engine/logits_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ def _filter_repetition_ngram_(
return scores
# use first stop words
_, found = ngram(generated_ids, n, threshold, max_n, max_ngram_window_size)
stop_words = stop_words[:, 0]
stop_words = stop_words[:, 0, 0]
# fill all scores -inf
scores.masked_fill_(found[:, None], -float('inf'))
# set stop words to 0
Expand Down Expand Up @@ -245,7 +245,7 @@ class SamplingInputs:
bad_words: torch.LongTensor = None
bad_mask: torch.BoolTensor = None
stop_words: torch.LongTensor = None
stop_mask: torch.BoolTensor = None
stop_word_lens: torch.LongTensor = None
repetition_penalty: torch.Tensor = None
top_k: torch.LongTensor = None
top_p: torch.Tensor = None
Expand Down Expand Up @@ -428,11 +428,13 @@ async def __call__(self, scores: torch.Tensor) -> torch.Tensor:
scores = _process_bad_words_(scores, bad_words, bad_mask)

stop_words = sampling_inputs.stop_words
if stop_words is not None:
stop_word_lens = sampling_inputs.stop_word_lens
if stop_words is not None and stop_word_lens is not None:
ignore_eos = sampling_inputs.num_ignore_eos > 0
stop_mask = sampling_inputs.stop_mask
stop_mask = torch.where(ignore_eos[:, None], stop_mask, False)
scores = _process_bad_words_(scores, stop_words, stop_mask)
single_mask = (stop_word_lens == 1) & ignore_eos[:, None]
if single_mask.any():
single_tokens = stop_words[:, :, 0]
scores = _process_bad_words_(scores, single_tokens, single_mask)

return scores, logprobs

Expand Down
6 changes: 5 additions & 1 deletion lmdeploy/pytorch/engine/model_agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -701,11 +701,15 @@ async def _step_postprocess_with_output(self,
logger.debug(f'<ForwardTask> rank[{rank}]: synchronize token ids')

# stopping criteria
# Use output_token_ids (all tokens accepted this step) so that multi-token
# stop sequences whose last token is not the final spec-decoded token are
# detected correctly. For non-spec AR, output_token_ids == next_token_ids.
stopped, stop_pos, stopping_criteria = stopping_criteria.step(
next_token_ids,
output_token_ids,
sampling_inputs.stop_words,
inputs=inputs,
extra_inputs=extra_inputs,
stop_word_lens=sampling_inputs.stop_word_lens,
)
Comment on lines 703 to 713
Copy link

Copilot AI Mar 24, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

stopping_criteria.step(...) is now called with the new keyword argument stop_word_lens=..., but most StoppingCriteria implementations (e.g. ARSpecStoppingCriteria.step, DLLMStoppingCriteria.step, and the abstract base signature) do not accept this parameter. This will raise TypeError: got an unexpected keyword argument 'stop_word_lens' when using non-AR strategies. Either update the base StoppingCriteria.step contract (and all implementations) to accept stop_word_lens (or **kwargs), or only pass stop_word_lens when the concrete stopping criteria supports it.

Copilot uses AI. Check for mistakes.

# send output
Expand Down
8 changes: 6 additions & 2 deletions lmdeploy/pytorch/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ class SamplingParam:
repetition_penalty: float = 1.0
ignore_eos: bool = False
random_seed: int = None
stop_words: list[int] = field(default_factory=list)
stop_words: list[list[int]] = field(default_factory=list)
bad_words: list[int] = field(default_factory=list)
max_new_tokens: int = 512
min_new_tokens: int = 0
Expand All @@ -75,7 +75,11 @@ def from_gen_config(cls, gen_config: GenerationConfig):
stop_words = gen_config.stop_token_ids or []
bad_words = gen_config.bad_token_ids or []
if gen_config.ignore_eos:
bad_words += stop_words
if any(len(s) > 1 for s in stop_words):
logger.warning('Multi-token stop words are not supported and '
'will be ignored. Only single-token stop words can '
'be used to stop generation.')
bad_words += [s[0] for s in stop_words if len(s) == 1]
stop_words = []

top_k = gen_config.top_k
Expand Down
Loading
Loading