Skip to content
Open

Fix mtp #4517

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 7 additions & 8 deletions lmdeploy/pytorch/spec_decode/proposers/deepseek_mtp.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,13 @@ def get_outputs(self,
"""Get outputs."""
hidden_states = model_outputs['hidden_states']
model_metas = model_outputs['model_metas']
if extra_inputs is not None and extra_inputs.last_token_indices is not None:
# for long input
if (not model_inputs.is_decoding) and model_inputs.seq_length.size(0) == 1:
hidden_states = hidden_states[:, -1:]
else:
last_token_loc = extra_inputs.last_token_indices
hidden_states = hidden_states[:, last_token_loc]
if extra_inputs is not None:
last_token_loc = extra_inputs.last_token_indices
target_hidden_states = model_inputs.target_hidden_states[:, last_token_loc]
hidden_states = hidden_states[:, last_token_loc]
else:
target_hidden_states = hidden_states

logits = self.get_logits(hidden_states)[0]
draft_token_ids = logits.argmax(dim=-1, keepdim=True)
return draft_token_ids, model_metas, hidden_states
return draft_token_ids, model_metas, target_hidden_states
100 changes: 34 additions & 66 deletions lmdeploy/pytorch/spec_decode/spec_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,69 +293,21 @@ def _prepare_long_context_chunk_prepend_saved(self, key, tensor, save_last=True)
self._prev_chunk_last.pop(key, None)
return torch.cat([saved, tensor], dim=1)

async def async_sampling_logits(self, target_logits: torch.Tensor, sampling_inputs: SamplingInputs):
    """Apply sampling post-processing to target logits and pick the bonus token.

    Args:
        target_logits: [batch_size, num_tokens, vocab_size] target-model logits;
            num_tokens is 1 + num_spec_tokens when decoding, 1 on prefill.
        sampling_inputs: sampling parameters already expanded by
            make_sampling_inputs to batch_size * num_tokens entries.

    Returns:
        processed_logits: [batch_size, num_tokens, vocab_size]
        next_token_ids: [batch_size], sampled at the last (bonus) position
        logprobs: BatchedLogProbs or None
    """
    with record_function('spec_sampling_logits'):
        batch_size, num_tokens, vocab_size = target_logits.shape
        logprobs_mode = self.misc_config.logprobs_mode

        # Flatten to [batch * num_tokens, vocab] for the fused processor.
        # TODO: guided decoding not supported yet for spec decoding
        processor = FusedLogitsProcessor(
            sampling_inputs,
            logprobs_mode=logprobs_mode,
        )
        processed_flat, raw_logprobs = await processor(target_logits.reshape(-1, vocab_size))

        # The bonus position is the last token of every sequence; stride-slice
        # those rows (num_tokens-1, 2*num_tokens-1, ...) out of the flat layout.
        last_pos_logits = processed_flat[num_tokens - 1::num_tokens]  # [batch_size, vocab]

        # Sample the bonus token with per-batch (un-expanded) sampling inputs,
        # obtained by slicing the expanded inputs back down to batch_size.
        per_batch_inputs = _slice_sampling_inputs(sampling_inputs, num_tokens)
        bonus_processor = FusedLogitsProcessor(
            per_batch_inputs,
            logprobs_mode=logprobs_mode,
        )
        next_token_ids = bonus_processor.sampling(last_pos_logits)  # [batch_size]

        # Restore the 3D [batch, num_tokens, vocab] view for the caller.
        return processed_flat.view(batch_size, num_tokens, vocab_size), next_token_ids, raw_logprobs

async def _rejection_sampling(self, model_inputs: 'ModelInputs', extra_inputs: ARSpecExtraInputs,
sampling_inputs: SamplingInputs):
"""Do rejection sampling."""

@torch.inference_mode()
def __compute_logprobs(raw_logprobs: torch.Tensor, token_ids: torch.LongTensor,
sampling_inputs: SamplingInputs):
max_num_logprobs: int):
"""Compute logprobs."""
if raw_logprobs is None or sampling_inputs.max_num_logprobs <= 0:
if raw_logprobs is None or max_num_logprobs <= 0:
return None

indices = token_ids.flatten().unsqueeze(-1)
clamped_indices = indices.clamp_min(0)
logprobs = raw_logprobs.gather(-1, clamped_indices)
num_logprobs = sampling_inputs.max_num_logprobs
topk_logprobs, topk_indices = _torch_topk(raw_logprobs, num_logprobs, dim=-1)
topk_logprobs, topk_indices = _torch_topk(raw_logprobs, max_num_logprobs, dim=-1)
logprobs = torch.cat([logprobs, topk_logprobs], dim=-1)
indices = torch.cat([indices, topk_indices], dim=-1).to(torch.int32)
output_logprobs = BatchedLogProbs(
Expand All @@ -366,30 +318,47 @@ def __compute_logprobs(raw_logprobs: torch.Tensor, token_ids: torch.LongTensor,

# Process target_logits via FusedLogitsProcessor for BOTH prefill and decoding
target_logits = extra_inputs.target_logits
num_tokens = target_logits.shape[1]
expanded_sampling_inputs = _expand_sampling_inputs(sampling_inputs, num_tokens)
processed_logits, next_token_ids, raw_logprobs = await self.async_sampling_logits(
target_logits, expanded_sampling_inputs)
batch_size = model_inputs.seq_length.size(0)

num_expand_sampling = 1 if not model_inputs.is_decoding else self.num_spec_tokens + 1
expanded_sampling_inputs = _expand_sampling_inputs(sampling_inputs, num_expand_sampling)
num_rejected_tokens = torch.zeros_like(model_inputs.seq_length)
output_token_ids = next_token_ids.unsqueeze(-1)
last_token_indices = model_inputs.seq_length.cumsum(0) - 1

logits_processor = FusedLogitsProcessor(
expanded_sampling_inputs,
logprobs_mode=self.misc_config.logprobs_mode,
)
if model_inputs.is_decoding:
# TODO: guided decoding not supported yet for spec decoding
processed_logits, raw_logprobs = await logits_processor(target_logits)
# Slice bonus (last) position logits for each batch element
bonus_logits = processed_logits[num_expand_sampling - 1::num_expand_sampling] # [batch_size, vocab]
# Create a per-batch processor for bonus token sampling
# by slicing the expanded sampling_inputs back to batch_size
bonus_sampling_inputs = _slice_sampling_inputs(expanded_sampling_inputs, num_expand_sampling)
logits_processor.sampling_inputs = bonus_sampling_inputs
# Sample next token from bonus position
next_token_ids = logits_processor.sampling(bonus_logits) # [batch_size]
# Reshape back to 3D
processed_logits = processed_logits.view(batch_size, num_expand_sampling, -1)
# Rejection sampling on processed logits (exclude bonus position)
target_logits = processed_logits[:, :-1].contiguous() # [batch, num_spec, vocab]
num_tokens = self.num_spec_tokens + 1
batch_sampling_inputs = _slice_sampling_inputs(expanded_sampling_inputs, num_tokens, is_last=False)
target_draft_logits = processed_logits[:, :-1].contiguous() # [batch, num_spec, vocab]
draft_sampling_inputs = _slice_sampling_inputs(expanded_sampling_inputs, num_expand_sampling, is_last=False)
output_token_ids, num_rejected_tokens, next_token_ids = self.rejection_sampler(
target_logits,
target_draft_logits,
extra_inputs.output_draft_token_ids,
next_token_ids,
sampling_inputs=batch_sampling_inputs,
sampling_inputs=draft_sampling_inputs,
)
# update last token indices
last_token_indices = last_token_indices - num_rejected_tokens
else:
bonus_logits, raw_logprobs = await logits_processor(target_logits)
# Sample next token from bonus position
next_token_ids = logits_processor.sampling(bonus_logits) # [batch_size]
output_token_ids = next_token_ids.unsqueeze(-1)

logprobs = __compute_logprobs(raw_logprobs, output_token_ids, sampling_inputs)
logprobs = __compute_logprobs(raw_logprobs, output_token_ids, sampling_inputs.max_num_logprobs)

new_extra_inputs = extra_inputs.clone(
next_token_ids=next_token_ids,
Expand Down Expand Up @@ -439,8 +408,6 @@ async def _async_model_forward(self, inputs: ModelInputs, extra_inputs: ARSpecEx
inputs.target_hidden_states = target_hidden_states
if inputs.target_position_ids is not None:
inputs.target_position_ids += 1
if inputs.mrope_pos_ids is not None:
inputs.mrope_pos_ids += 1

output_draft_ids = torch.cat(draft_tokens_li, dim=-1)

Expand All @@ -461,7 +428,8 @@ async def async_model_forward(
sampling_inputs: SamplingInputs,
):
"""Draft model forward."""
draft_extra_inputs = await self._rejection_sampling(model_inputs, extra_inputs, sampling_inputs)
with record_function('spec_rejection_sampling'):
draft_extra_inputs = await self._rejection_sampling(model_inputs, extra_inputs, sampling_inputs)
draft_model_inputs, draft_extra_inputs = self._prepare_inputs_from_main(model_inputs, draft_extra_inputs)
return await self._async_model_forward(draft_model_inputs, draft_extra_inputs, sampling_inputs)

Expand Down
9 changes: 1 addition & 8 deletions lmdeploy/pytorch/strategies/ar_spec/model_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,14 +151,7 @@ def slice_outputs(self, inputs: torch.Tensor, seq_length: torch.LongTensor) -> t
def slice_extra_inputs(self, extra_inputs: ARSpecExtraInputs, model_inputs: ModelInputs,
model_outputs: dict[str, torch.Tensor], **kwargs) -> ARSpecExtraInputs:
"""Slice outputs."""
if model_inputs.is_decoding:
batch_size = model_inputs.seq_length.size(0)
raw_logits = model_outputs['logits'][0]
target_logits = raw_logits.unflatten(0, (batch_size, -1))
else:
# prefill: last token logits
raw_logits = model_outputs['logits'][0]
target_logits = raw_logits.unsqueeze(1)
target_logits = model_outputs['logits'][0]
return extra_inputs.clone(
target_hidden_states=model_outputs.get('hidden_states'),
target_position_ids=model_outputs.get('position_ids', None),
Expand Down
189 changes: 0 additions & 189 deletions tests/pytorch/spec_decode/test_spec_agent.py
Original file line number Diff line number Diff line change
@@ -1,199 +1,10 @@
import asyncio

import pytest
import torch

from lmdeploy.pytorch.spec_decode.spec_agent import _expand_sampling_inputs

device = 'cuda' if torch.cuda.is_available() else 'cpu'


def _run_async(coro):
    """Run *coro* to completion on a fresh event loop and return its result.

    Uses :func:`asyncio.run`, which creates a new event loop, runs the
    coroutine, and reliably closes the loop (including async-generator and
    executor shutdown) even on error — equivalent to the manual
    ``new_event_loop()`` / ``close()`` dance, minus the boilerplate and the
    cleanup steps that dance skipped.
    """
    return asyncio.run(coro)


def _make_spec_agent_for_sampling(misc_config=None):
    """Build a bare SpecModelAgent carrying only the state that
    async_sampling_logits reads (misc_config); __init__ is bypassed."""
    from lmdeploy.pytorch.config import MiscConfig
    from lmdeploy.pytorch.spec_decode.spec_agent import SpecModelAgent

    stub = object.__new__(SpecModelAgent)
    stub.misc_config = misc_config if misc_config else MiscConfig()
    return stub


def test_async_sampling_logits_greedy_prefill():
    """Greedy sampling through async_sampling_logits for a prefill step (a
    single bonus position per sequence)."""
    from lmdeploy.pytorch.engine.logits_process import SamplingInputs

    agent = _make_spec_agent_for_sampling()

    bs, n_pos, vocab = 2, 1, 32  # prefill: one position per sequence
    logits_in = torch.randn(bs, n_pos, vocab, device=device)
    inputs = SamplingInputs(
        max_top_k=1,
        max_num_logprobs=-1,
        logits_processors=[],
        batch_size=bs,
    )
    # n_pos == 1, so no expansion is required.
    processed, next_ids, logprobs = _run_async(agent.async_sampling_logits(logits_in, inputs))

    # Shapes are preserved / reduced as documented.
    assert processed.shape == (bs, n_pos, vocab)
    assert next_ids.shape == (bs, )
    # Greedy picks the argmax of the bonus (only) position.
    torch.testing.assert_close(next_ids, logits_in[:, -1, :].argmax(dim=-1))
    # max_num_logprobs=-1 -> no logprobs requested.
    assert logprobs is None


def test_async_sampling_logits_greedy_decoding():
    """Greedy sampling for a decoding step (num_spec_tokens + 1 positions per
    sequence)."""
    from lmdeploy.pytorch.engine.logits_process import SamplingInputs

    agent = _make_spec_agent_for_sampling()

    bs = 3
    n_spec = 4
    n_pos = n_spec + 1  # decoding: spec tokens plus the bonus position
    vocab = 64
    logits_in = torch.randn(bs, n_pos, vocab, device=device)
    base_inputs = SamplingInputs(
        max_top_k=1,
        max_num_logprobs=-1,
        logits_processors=[],
        batch_size=bs,
    )
    # Decoding requires the sampling inputs expanded to bs * n_pos.
    expanded_inputs = _expand_sampling_inputs(base_inputs, n_pos)

    processed, next_ids, logprobs = _run_async(agent.async_sampling_logits(logits_in, expanded_inputs))

    assert processed.shape == (bs, n_pos, vocab)
    assert next_ids.shape == (bs, )
    # Bonus token is the greedy argmax of the last position.
    torch.testing.assert_close(next_ids, logits_in[:, -1, :].argmax(dim=-1))
    assert logprobs is None


def test_async_sampling_logits_random():
    """Random (top-k / temperature) sampling returns valid token ids."""
    from lmdeploy.pytorch.engine.logits_process import SamplingInputs

    agent = _make_spec_agent_for_sampling()

    bs, n_pos, vocab = 2, 4, 32
    logits_in = torch.randn(bs, n_pos, vocab, device=device)
    inputs = SamplingInputs(
        max_top_k=10,
        top_k=torch.full((bs, ), 10, device=device),
        temperature=torch.ones(bs, device=device),
        random_seeds=torch.randint(0, 2**31, (bs, ), dtype=torch.long, device=device),
        random_offsets=torch.zeros(bs, dtype=torch.long, device=device),
        max_num_logprobs=-1,
        logits_processors=[],
        batch_size=bs,
    )
    expanded_inputs = _expand_sampling_inputs(inputs, n_pos)

    processed, next_ids, logprobs = _run_async(agent.async_sampling_logits(logits_in, expanded_inputs))

    assert processed.shape == (bs, n_pos, vocab)
    assert next_ids.shape == (bs, )
    # Sampled ids must lie inside the vocabulary.
    assert (next_ids >= 0).all()
    assert (next_ids < vocab).all()
    assert logprobs is None


def test_async_sampling_logits_with_logprobs():
    """Requesting logprobs yields a raw-logprobs tensor of the right shape."""
    from lmdeploy.pytorch.config import MiscConfig
    from lmdeploy.pytorch.engine.logits_process import SamplingInputs

    agent = _make_spec_agent_for_sampling(misc_config=MiscConfig(logprobs_mode='raw_logprobs'))

    bs, n_pos, vocab = 2, 3, 32
    logits_in = torch.randn(bs, n_pos, vocab, device=device)
    inputs = SamplingInputs(
        max_top_k=1,
        max_num_logprobs=5,
        logits_processors=[],
        batch_size=bs,
    )
    expanded_inputs = _expand_sampling_inputs(inputs, n_pos)

    processed, next_ids, logprobs = _run_async(agent.async_sampling_logits(logits_in, expanded_inputs))

    assert processed.shape == (bs, n_pos, vocab)
    assert next_ids.shape == (bs, )
    assert logprobs is not None
    # Raw logprobs stay in the flattened layout: [bs * n_pos, vocab].
    assert logprobs.shape == (bs * n_pos, vocab)


def test_async_sampling_logits_temperature():
    """Temperature divides the logits before sampling."""
    from lmdeploy.pytorch.engine.logits_process import SamplingInputs

    agent = _make_spec_agent_for_sampling()

    bs, n_pos, vocab = 1, 1, 16
    # One dominant token: index 0 gets logit 100, everything else 0.
    logits_in = torch.zeros(bs, n_pos, vocab, device=device)
    logits_in[0, 0, 0] = 100.0
    peak = logits_in[0, 0, 0].item()

    inputs = SamplingInputs(
        max_top_k=1,
        temperature=torch.tensor([0.5], device=device),
        max_num_logprobs=-1,
        logits_processors=[],
        batch_size=bs,
    )
    # n_pos == 1, so no expansion is required.
    processed, next_ids, _ = _run_async(agent.async_sampling_logits(logits_in, inputs))

    # Greedy still selects the dominant token...
    assert next_ids[0] == 0
    # ...and its logit was divided by the temperature: 100 / 0.5 == 200.
    assert processed[0, 0, 0].item() == pytest.approx(peak / 0.5, rel=1e-3)


def test_slice_sampling_inputs_decode():
"""Test _slice_sampling_inputs with decoding (num_tokens_per_batch > 1)."""
from lmdeploy.pytorch.engine.logits_process import SamplingInputs
Expand Down
Loading