Skip to content
Merged

Fix mtp #4517

Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 7 additions & 8 deletions lmdeploy/pytorch/spec_decode/proposers/deepseek_mtp.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,13 @@ def get_outputs(self,
"""Get outputs."""
hidden_states = model_outputs['hidden_states']
model_metas = model_outputs['model_metas']
if extra_inputs is not None and extra_inputs.last_token_indices is not None:
# for long input
if (not model_inputs.is_decoding) and model_inputs.seq_length.size(0) == 1:
hidden_states = hidden_states[:, -1:]
else:
last_token_loc = extra_inputs.last_token_indices
hidden_states = hidden_states[:, last_token_loc]
if extra_inputs is not None:
last_token_loc = extra_inputs.last_token_indices
target_hidden_states = model_inputs.target_hidden_states[:, last_token_loc]
hidden_states = hidden_states[:, last_token_loc]
else:
target_hidden_states = hidden_states

logits = self.get_logits(hidden_states)[0]
draft_token_ids = logits.argmax(dim=-1, keepdim=True)
return draft_token_ids, model_metas, hidden_states
return draft_token_ids, model_metas, target_hidden_states
100 changes: 34 additions & 66 deletions lmdeploy/pytorch/spec_decode/spec_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,69 +293,21 @@ def _prepare_long_context_chunk_prepend_saved(self, key, tensor, save_last=True)
self._prev_chunk_last.pop(key, None)
return torch.cat([saved, tensor], dim=1)

async def async_sampling_logits(self, target_logits: torch.Tensor, sampling_inputs: SamplingInputs):
    """Run FusedLogitsProcessor over the target logits and sample the bonus token.

    Args:
        target_logits: [batch_size, num_tokens, vocab_size], where
            num_tokens = 1 + num_spec_tokens when decoding, or 1 for prefill.
        sampling_inputs: SamplingInputs already expanded by
            make_sampling_inputs to batch_size * (num_spec_tokens + 1).

    Returns:
        processed_logits: [batch_size, num_tokens, vocab_size]
        next_token_ids: [batch_size], sampled from the bonus (last) position.
        logprobs: BatchedLogProbs or None.
    """
    with record_function('spec_sampling_logits'):
        batch_size, num_tokens, vocab_size = target_logits.shape

        # Flatten batch/token dims so the processor sees [batch * num_tokens, vocab].
        logits_2d = target_logits.reshape(-1, vocab_size)

        # TODO: guided decoding not supported yet for spec decoding
        # sampling_inputs is already expanded to batch_size * num_tokens
        main_processor = FusedLogitsProcessor(
            sampling_inputs,
            logprobs_mode=self.misc_config.logprobs_mode,
        )
        processed_logits, raw_logprobs = await main_processor(logits_2d)

        # Every num_tokens-th row — the last position of each sequence —
        # holds the bonus-token logits.
        bonus_logits = processed_logits[num_tokens - 1::num_tokens]  # [batch_size, vocab]

        # Shrink the expanded sampling inputs back to one entry per batch
        # element and sample the bonus token with a dedicated processor.
        per_batch_inputs = _slice_sampling_inputs(sampling_inputs, num_tokens)
        bonus_processor = FusedLogitsProcessor(
            per_batch_inputs,
            logprobs_mode=self.misc_config.logprobs_mode,
        )
        next_token_ids = bonus_processor.sampling(bonus_logits)  # [batch_size]

        # Restore the [batch, tokens, vocab] layout for the caller.
        processed_logits = processed_logits.view(batch_size, num_tokens, vocab_size)

        return processed_logits, next_token_ids, raw_logprobs

async def _rejection_sampling(self, model_inputs: 'ModelInputs', extra_inputs: ARSpecExtraInputs,
sampling_inputs: SamplingInputs):
"""Do rejection sampling."""

@torch.inference_mode()
def __compute_logprobs(raw_logprobs: torch.Tensor, token_ids: torch.LongTensor,
sampling_inputs: SamplingInputs):
max_num_logprobs: int):
"""Compute logprobs."""
if raw_logprobs is None or sampling_inputs.max_num_logprobs <= 0:
if raw_logprobs is None or max_num_logprobs <= 0:
return None

indices = token_ids.flatten().unsqueeze(-1)
clamped_indices = indices.clamp_min(0)
logprobs = raw_logprobs.gather(-1, clamped_indices)
num_logprobs = sampling_inputs.max_num_logprobs
topk_logprobs, topk_indices = _torch_topk(raw_logprobs, num_logprobs, dim=-1)
topk_logprobs, topk_indices = _torch_topk(raw_logprobs, max_num_logprobs, dim=-1)
logprobs = torch.cat([logprobs, topk_logprobs], dim=-1)
indices = torch.cat([indices, topk_indices], dim=-1).to(torch.int32)
output_logprobs = BatchedLogProbs(
Expand All @@ -366,30 +318,47 @@ def __compute_logprobs(raw_logprobs: torch.Tensor, token_ids: torch.LongTensor,

# Process target_logits via FusedLogitsProcessor for BOTH prefill and decoding
target_logits = extra_inputs.target_logits
num_tokens = target_logits.shape[1]
expanded_sampling_inputs = _expand_sampling_inputs(sampling_inputs, num_tokens)
processed_logits, next_token_ids, raw_logprobs = await self.async_sampling_logits(
target_logits, expanded_sampling_inputs)
batch_size = model_inputs.seq_length.size(0)

num_expand_sampling = 1 if not model_inputs.is_decoding else self.num_spec_tokens + 1
expanded_sampling_inputs = _expand_sampling_inputs(sampling_inputs, num_expand_sampling)
num_rejected_tokens = torch.zeros_like(model_inputs.seq_length)
output_token_ids = next_token_ids.unsqueeze(-1)
last_token_indices = model_inputs.seq_length.cumsum(0) - 1

logits_processor = FusedLogitsProcessor(
expanded_sampling_inputs,
logprobs_mode=self.misc_config.logprobs_mode,
)
if model_inputs.is_decoding:
# TODO: guided decoding not supported yet for spec decoding
processed_logits, raw_logprobs = await logits_processor(target_logits)
# Slice bonus (last) position logits for each batch element
bonus_logits = processed_logits[num_expand_sampling - 1::num_expand_sampling] # [batch_size, vocab]
# Create a per-batch processor for bonus token sampling
# by slicing the expanded sampling_inputs back to batch_size
bonus_sampling_inputs = _slice_sampling_inputs(expanded_sampling_inputs, num_expand_sampling)
logits_processor.sampling_inputs = bonus_sampling_inputs
# Sample next token from bonus position
next_token_ids = logits_processor.sampling(bonus_logits) # [batch_size]
# Reshape back to 3D
processed_logits = processed_logits.view(batch_size, num_expand_sampling, -1)
# Rejection sampling on processed logits (exclude bonus position)
target_logits = processed_logits[:, :-1].contiguous() # [batch, num_spec, vocab]
num_tokens = self.num_spec_tokens + 1
batch_sampling_inputs = _slice_sampling_inputs(expanded_sampling_inputs, num_tokens, is_last=False)
target_draft_logits = processed_logits[:, :-1].contiguous() # [batch, num_spec, vocab]
draft_sampling_inputs = _slice_sampling_inputs(expanded_sampling_inputs, num_expand_sampling, is_last=False)
output_token_ids, num_rejected_tokens, next_token_ids = self.rejection_sampler(
target_logits,
target_draft_logits,
extra_inputs.output_draft_token_ids,
next_token_ids,
sampling_inputs=batch_sampling_inputs,
sampling_inputs=draft_sampling_inputs,
)
# update last token indices
last_token_indices = last_token_indices - num_rejected_tokens
else:
bonus_logits, raw_logprobs = await logits_processor(target_logits)
# Sample next token from bonus position
next_token_ids = logits_processor.sampling(bonus_logits) # [batch_size]
output_token_ids = next_token_ids.unsqueeze(-1)

logprobs = __compute_logprobs(raw_logprobs, output_token_ids, sampling_inputs)
logprobs = __compute_logprobs(raw_logprobs, output_token_ids, sampling_inputs.max_num_logprobs)

new_extra_inputs = extra_inputs.clone(
next_token_ids=next_token_ids,
Expand Down Expand Up @@ -439,8 +408,6 @@ async def _async_model_forward(self, inputs: ModelInputs, extra_inputs: ARSpecEx
inputs.target_hidden_states = target_hidden_states
if inputs.target_position_ids is not None:
inputs.target_position_ids += 1
if inputs.mrope_pos_ids is not None:
inputs.mrope_pos_ids += 1

output_draft_ids = torch.cat(draft_tokens_li, dim=-1)

Expand All @@ -461,7 +428,8 @@ async def async_model_forward(
sampling_inputs: SamplingInputs,
):
"""Draft model forward."""
draft_extra_inputs = await self._rejection_sampling(model_inputs, extra_inputs, sampling_inputs)
with record_function('spec_rejection_sampling'):
draft_extra_inputs = await self._rejection_sampling(model_inputs, extra_inputs, sampling_inputs)
draft_model_inputs, draft_extra_inputs = self._prepare_inputs_from_main(model_inputs, draft_extra_inputs)
return await self._async_model_forward(draft_model_inputs, draft_extra_inputs, sampling_inputs)

Expand Down
9 changes: 1 addition & 8 deletions lmdeploy/pytorch/strategies/ar_spec/model_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,14 +151,7 @@ def slice_outputs(self, inputs: torch.Tensor, seq_length: torch.LongTensor) -> t
def slice_extra_inputs(self, extra_inputs: ARSpecExtraInputs, model_inputs: ModelInputs,
model_outputs: dict[str, torch.Tensor], **kwargs) -> ARSpecExtraInputs:
"""Slice outputs."""
if model_inputs.is_decoding:
batch_size = model_inputs.seq_length.size(0)
raw_logits = model_outputs['logits'][0]
target_logits = raw_logits.unflatten(0, (batch_size, -1))
else:
# prefill: last token logits
raw_logits = model_outputs['logits'][0]
target_logits = raw_logits.unsqueeze(1)
target_logits = model_outputs['logits'][0]
return extra_inputs.clone(
target_hidden_states=model_outputs.get('hidden_states'),
target_position_ids=model_outputs.get('position_ids', None),
Expand Down
Loading