diff --git a/lmdeploy/pytorch/spec_decode/proposers/deepseek_mtp.py b/lmdeploy/pytorch/spec_decode/proposers/deepseek_mtp.py index 09e4c08591..c0d65ca33b 100644 --- a/lmdeploy/pytorch/spec_decode/proposers/deepseek_mtp.py +++ b/lmdeploy/pytorch/spec_decode/proposers/deepseek_mtp.py @@ -21,14 +21,13 @@ def get_outputs(self, """Get outputs.""" hidden_states = model_outputs['hidden_states'] model_metas = model_outputs['model_metas'] - if extra_inputs is not None and extra_inputs.last_token_indices is not None: - # for long input - if (not model_inputs.is_decoding) and model_inputs.seq_length.size(0) == 1: - hidden_states = hidden_states[:, -1:] - else: - last_token_loc = extra_inputs.last_token_indices - hidden_states = hidden_states[:, last_token_loc] + if extra_inputs is not None: + last_token_loc = extra_inputs.last_token_indices + target_hidden_states = model_inputs.target_hidden_states[:, last_token_loc] + hidden_states = hidden_states[:, last_token_loc] + else: + target_hidden_states = hidden_states logits = self.get_logits(hidden_states)[0] draft_token_ids = logits.argmax(dim=-1, keepdim=True) - return draft_token_ids, model_metas, hidden_states + return draft_token_ids, model_metas, target_hidden_states diff --git a/lmdeploy/pytorch/spec_decode/spec_agent.py b/lmdeploy/pytorch/spec_decode/spec_agent.py index 9aa156ebfe..5bfe5275ae 100644 --- a/lmdeploy/pytorch/spec_decode/spec_agent.py +++ b/lmdeploy/pytorch/spec_decode/spec_agent.py @@ -293,69 +293,21 @@ def _prepare_long_context_chunk_prepend_saved(self, key, tensor, save_last=True) self._prev_chunk_last.pop(key, None) return torch.cat([saved, tensor], dim=1) - async def async_sampling_logits(self, target_logits: torch.Tensor, sampling_inputs: SamplingInputs): - """Process target logits and sample bonus token using - FusedLogitsProcessor. - - Args: - target_logits: [batch_size, num_tokens, vocab_size] - num_tokens = 1 + num_spec_tokens (decoding) or 1 (prefill) - sampling_inputs: SamplingInputs — already expanded by - make_sampling_inputs to batch_size * (num_spec_tokens + 1) - - Returns: - processed_logits: [batch_size, num_tokens, vocab_size] - next_token_ids: [batch_size] — sampled from the bonus (last) position - logprobs: BatchedLogProbs or None - """ - with record_function('spec_sampling_logits'): - batch_size, num_tokens, vocab_size = target_logits.shape - - # Reshape to 2D: [batch * num_tokens, vocab] - flat_logits = target_logits.reshape(-1, vocab_size) - - # TODO: guided decoding not supported yet for spec decoding - # sampling_inputs is already expanded to batch_size * num_tokens - logits_processor = FusedLogitsProcessor( - sampling_inputs, - logprobs_mode=self.misc_config.logprobs_mode, - ) - processed_logits, raw_logprobs = await logits_processor(flat_logits) - - # Slice bonus (last) position logits for each batch element - bonus_logits = processed_logits[num_tokens - 1::num_tokens] # [batch_size, vocab] - - # Create a per-batch processor for bonus token sampling - # by slicing the expanded sampling_inputs back to batch_size - bonus_sampling_inputs = _slice_sampling_inputs(sampling_inputs, num_tokens) - bonus_processor = FusedLogitsProcessor( - bonus_sampling_inputs, - logprobs_mode=self.misc_config.logprobs_mode, - ) - # Sample next token from bonus position - next_token_ids = bonus_processor.sampling(bonus_logits) # [batch_size] - - # Reshape back to 3D - processed_logits = processed_logits.view(batch_size, num_tokens, vocab_size) - - return processed_logits, next_token_ids, raw_logprobs - async def _rejection_sampling(self, model_inputs: 'ModelInputs', extra_inputs: ARSpecExtraInputs, sampling_inputs: SamplingInputs): """Do rejection sampling.""" @torch.inference_mode() def __compute_logprobs(raw_logprobs: torch.Tensor, token_ids: torch.LongTensor, - sampling_inputs: SamplingInputs): + max_num_logprobs: int): """Compute logprobs.""" - if raw_logprobs is None or sampling_inputs.max_num_logprobs <= 0: + if raw_logprobs is None or max_num_logprobs <= 0: return None indices = token_ids.flatten().unsqueeze(-1) clamped_indices = indices.clamp_min(0) logprobs = raw_logprobs.gather(-1, clamped_indices) - num_logprobs = sampling_inputs.max_num_logprobs - topk_logprobs, topk_indices = _torch_topk(raw_logprobs, num_logprobs, dim=-1) + topk_logprobs, topk_indices = _torch_topk(raw_logprobs, max_num_logprobs, dim=-1) logprobs = torch.cat([logprobs, topk_logprobs], dim=-1) indices = torch.cat([indices, topk_indices], dim=-1).to(torch.int32) output_logprobs = BatchedLogProbs( @@ -366,30 +318,47 @@ def __compute_logprobs(raw_logprobs: torch.Tensor, token_ids: torch.LongTensor, # Process target_logits via FusedLogitsProcessor for BOTH prefill and decoding target_logits = extra_inputs.target_logits - num_tokens = target_logits.shape[1] - expanded_sampling_inputs = _expand_sampling_inputs(sampling_inputs, num_tokens) - processed_logits, next_token_ids, raw_logprobs = await self.async_sampling_logits( - target_logits, expanded_sampling_inputs) + batch_size = model_inputs.seq_length.size(0) + num_expand_sampling = 1 if not model_inputs.is_decoding else self.num_spec_tokens + 1 + expanded_sampling_inputs = _expand_sampling_inputs(sampling_inputs, num_expand_sampling) num_rejected_tokens = torch.zeros_like(model_inputs.seq_length) - output_token_ids = next_token_ids.unsqueeze(-1) last_token_indices = model_inputs.seq_length.cumsum(0) - 1 - + logits_processor = FusedLogitsProcessor( + expanded_sampling_inputs, + logprobs_mode=self.misc_config.logprobs_mode, + ) if model_inputs.is_decoding: + # TODO: guided decoding not supported yet for spec decoding + processed_logits, raw_logprobs = await logits_processor(target_logits) + # Slice bonus (last) position logits for each batch element + bonus_logits = processed_logits[num_expand_sampling - 1::num_expand_sampling] # [batch_size, vocab] + # Create a per-batch processor for bonus token sampling + # by slicing the expanded sampling_inputs back to batch_size + bonus_sampling_inputs = _slice_sampling_inputs(expanded_sampling_inputs, num_expand_sampling) + logits_processor.sampling_inputs = bonus_sampling_inputs + # Sample next token from bonus position + next_token_ids = logits_processor.sampling(bonus_logits) # [batch_size] + # Reshape back to 3D + processed_logits = processed_logits.view(batch_size, num_expand_sampling, -1) # Rejection sampling on processed logits (exclude bonus position) - target_logits = processed_logits[:, :-1].contiguous() # [batch, num_spec, vocab] - num_tokens = self.num_spec_tokens + 1 - batch_sampling_inputs = _slice_sampling_inputs(expanded_sampling_inputs, num_tokens, is_last=False) + target_draft_logits = processed_logits[:, :-1].contiguous() # [batch, num_spec, vocab] + draft_sampling_inputs = _slice_sampling_inputs(expanded_sampling_inputs, num_expand_sampling, is_last=False) output_token_ids, num_rejected_tokens, next_token_ids = self.rejection_sampler( - target_logits, + target_draft_logits, extra_inputs.output_draft_token_ids, next_token_ids, - sampling_inputs=batch_sampling_inputs, + sampling_inputs=draft_sampling_inputs, ) # update last token indices last_token_indices = last_token_indices - num_rejected_tokens + else: + bonus_logits, raw_logprobs = await logits_processor(target_logits) + # Sample next token from bonus position + next_token_ids = logits_processor.sampling(bonus_logits) # [batch_size] + output_token_ids = next_token_ids.unsqueeze(-1) - logprobs = __compute_logprobs(raw_logprobs, output_token_ids, sampling_inputs) + logprobs = __compute_logprobs(raw_logprobs, output_token_ids, sampling_inputs.max_num_logprobs) new_extra_inputs = extra_inputs.clone( next_token_ids=next_token_ids, @@ -439,8 +408,6 @@ async def _async_model_forward(self, inputs: ModelInputs, extra_inputs: ARSpecEx inputs.target_hidden_states = target_hidden_states if inputs.target_position_ids is not None: inputs.target_position_ids += 1 - if inputs.mrope_pos_ids is not None: - inputs.mrope_pos_ids += 1 output_draft_ids = torch.cat(draft_tokens_li, dim=-1) @@ -461,7 +428,8 @@ async def async_model_forward( sampling_inputs: SamplingInputs, ): """Draft model forward.""" - draft_extra_inputs = await self._rejection_sampling(model_inputs, extra_inputs, sampling_inputs) + with record_function('spec_rejection_sampling'): + draft_extra_inputs = await self._rejection_sampling(model_inputs, extra_inputs, sampling_inputs) draft_model_inputs, draft_extra_inputs = self._prepare_inputs_from_main(model_inputs, draft_extra_inputs) return await self._async_model_forward(draft_model_inputs, draft_extra_inputs, sampling_inputs) diff --git a/lmdeploy/pytorch/strategies/ar_spec/model_agent.py b/lmdeploy/pytorch/strategies/ar_spec/model_agent.py index f66ee61d94..a948ed30b6 100644 --- a/lmdeploy/pytorch/strategies/ar_spec/model_agent.py +++ b/lmdeploy/pytorch/strategies/ar_spec/model_agent.py @@ -151,14 +151,7 @@ def slice_outputs(self, inputs: torch.Tensor, seq_length: torch.LongTensor) -> t def slice_extra_inputs(self, extra_inputs: ARSpecExtraInputs, model_inputs: ModelInputs, model_outputs: dict[str, torch.Tensor], **kwargs) -> ARSpecExtraInputs: """Slice outputs.""" - if model_inputs.is_decoding: - batch_size = model_inputs.seq_length.size(0) - raw_logits = model_outputs['logits'][0] - target_logits = raw_logits.unflatten(0, (batch_size, -1)) - else: - # prefill: last token logits - raw_logits = model_outputs['logits'][0] - target_logits = raw_logits.unsqueeze(1) + target_logits = model_outputs['logits'][0] return extra_inputs.clone( target_hidden_states=model_outputs.get('hidden_states'), target_position_ids=model_outputs.get('position_ids', None), diff --git a/tests/pytorch/spec_decode/test_spec_agent.py b/tests/pytorch/spec_decode/test_spec_agent.py index cfdf85c36b..91dc444d64 100644 --- a/tests/pytorch/spec_decode/test_spec_agent.py +++ b/tests/pytorch/spec_decode/test_spec_agent.py @@ -1,6 +1,3 @@ -import asyncio - -import pytest import torch from lmdeploy.pytorch.spec_decode.spec_agent import _expand_sampling_inputs @@ -8,192 +5,6 @@ device = 'cuda' if torch.cuda.is_available() else 'cpu' -def _run_async(coro): - """Helper to run async function in sync test.""" - loop = asyncio.new_event_loop() - try: - return loop.run_until_complete(coro) - finally: - loop.close() - - -def _make_spec_agent_for_sampling(misc_config=None): - """Create a minimal SpecModelAgent with only the fields needed for - async_sampling_logits.""" - from lmdeploy.pytorch.config import MiscConfig - from lmdeploy.pytorch.spec_decode.spec_agent import SpecModelAgent - - agent = object.__new__(SpecModelAgent) - agent.misc_config = misc_config or MiscConfig() - return agent - - -def test_async_sampling_logits_greedy_prefill(): - """Test async_sampling_logits with greedy sampling, prefill (1 - position).""" - from lmdeploy.pytorch.engine.logits_process import SamplingInputs - - agent = _make_spec_agent_for_sampling() - - batch_size = 2 - num_tokens_per_batch = 1 # prefill - vocab_size = 32 - - target_logits = torch.randn(batch_size, num_tokens_per_batch, vocab_size, device=device) - sampling_inputs = SamplingInputs( - max_top_k=1, - max_num_logprobs=-1, - logits_processors=[], - batch_size=batch_size, - ) - # num_tokens_per_batch=1, no expansion needed - - processed, next_ids, logprobs = _run_async(agent.async_sampling_logits(target_logits, sampling_inputs)) - - # Output shapes - assert processed.shape == (batch_size, num_tokens_per_batch, vocab_size) - assert next_ids.shape == (batch_size, ) - # Greedy: next_ids should be argmax of bonus (only) position - expected = target_logits[:, -1, :].argmax(dim=-1) - torch.testing.assert_close(next_ids, expected) - # No logprobs requested - assert logprobs is None - - -def test_async_sampling_logits_greedy_decoding(): - """Test async_sampling_logits with greedy sampling, decoding (multiple - positions).""" - from lmdeploy.pytorch.engine.logits_process import SamplingInputs - - agent = _make_spec_agent_for_sampling() - - batch_size = 3 - num_spec_tokens = 4 - num_tokens_per_batch = 1 + num_spec_tokens # decoding - vocab_size = 64 - - target_logits = torch.randn(batch_size, num_tokens_per_batch, vocab_size, device=device) - sampling_inputs = SamplingInputs( - max_top_k=1, - max_num_logprobs=-1, - logits_processors=[], - batch_size=batch_size, - ) - # Expand for decoding - expanded = _expand_sampling_inputs(sampling_inputs, num_tokens_per_batch) - - processed, next_ids, logprobs = _run_async(agent.async_sampling_logits(target_logits, expanded)) - - assert processed.shape == (batch_size, num_tokens_per_batch, vocab_size) - assert next_ids.shape == (batch_size, ) - # Greedy: bonus token is argmax of last position - expected = target_logits[:, -1, :].argmax(dim=-1) - torch.testing.assert_close(next_ids, expected) - assert logprobs is None - - -def test_async_sampling_logits_random(): - """Test async_sampling_logits with random sampling.""" - from lmdeploy.pytorch.engine.logits_process import SamplingInputs - - agent = _make_spec_agent_for_sampling() - - batch_size = 2 - num_tokens_per_batch = 4 - vocab_size = 32 - - target_logits = torch.randn(batch_size, num_tokens_per_batch, vocab_size, device=device) - temperature = torch.ones(batch_size, device=device) - top_k = torch.full((batch_size, ), 10, device=device) - random_seeds = torch.randint(0, 2**31, (batch_size, ), dtype=torch.long, device=device) - random_offsets = torch.zeros(batch_size, dtype=torch.long, device=device) - - sampling_inputs = SamplingInputs( - max_top_k=10, - top_k=top_k, - temperature=temperature, - random_seeds=random_seeds, - random_offsets=random_offsets, - max_num_logprobs=-1, - logits_processors=[], - batch_size=batch_size, - ) - expanded = _expand_sampling_inputs(sampling_inputs, num_tokens_per_batch) - - processed, next_ids, logprobs = _run_async(agent.async_sampling_logits(target_logits, expanded)) - - assert processed.shape == (batch_size, num_tokens_per_batch, vocab_size) - assert next_ids.shape == (batch_size, ) - # Token ids should be valid - assert (next_ids >= 0).all() - assert (next_ids < vocab_size).all() - assert logprobs is None - - -def test_async_sampling_logits_with_logprobs(): - """Test async_sampling_logits returns logprobs when requested.""" - from lmdeploy.pytorch.config import MiscConfig - from lmdeploy.pytorch.engine.logits_process import SamplingInputs - - misc_config = MiscConfig(logprobs_mode='raw_logprobs') - agent = _make_spec_agent_for_sampling(misc_config=misc_config) - - batch_size = 2 - num_tokens_per_batch = 3 - vocab_size = 32 - - target_logits = torch.randn(batch_size, num_tokens_per_batch, vocab_size, device=device) - sampling_inputs = SamplingInputs( - max_top_k=1, - max_num_logprobs=5, - logits_processors=[], - batch_size=batch_size, - ) - expanded = _expand_sampling_inputs(sampling_inputs, num_tokens_per_batch) - - processed, next_ids, logprobs = _run_async(agent.async_sampling_logits(target_logits, expanded)) - - assert processed.shape == (batch_size, num_tokens_per_batch, vocab_size) - assert next_ids.shape == (batch_size, ) - assert logprobs is not None - # raw_logprobs shape: [batch_size * num_tokens_per_batch, vocab_size] - assert logprobs.shape == (batch_size * num_tokens_per_batch, vocab_size) - - -def test_async_sampling_logits_temperature(): - """Test that temperature scaling is applied correctly.""" - from lmdeploy.pytorch.engine.logits_process import SamplingInputs - - agent = _make_spec_agent_for_sampling() - - batch_size = 1 - num_tokens_per_batch = 1 - vocab_size = 16 - - # Strong logits: token 0 = 100, rest = 0 - target_logits = torch.zeros(batch_size, num_tokens_per_batch, vocab_size, device=device) - target_logits[0, 0, 0] = 100.0 - original_val = target_logits[0, 0, 0].item() - - temperature = torch.tensor([0.5], device=device) - sampling_inputs = SamplingInputs( - max_top_k=1, - temperature=temperature, - max_num_logprobs=-1, - logits_processors=[], - batch_size=batch_size, - ) - # num_tokens_per_batch=1, no expansion needed - - processed, next_ids, _ = _run_async(agent.async_sampling_logits(target_logits, sampling_inputs)) - - # Temperature divides logits: 100/0.5 = 200 at position 0 - # Greedy should still pick token 0 - assert next_ids[0] == 0 - # Temperature-scaled logits should be 200 - assert processed[0, 0, 0].item() == pytest.approx(original_val / 0.5, rel=1e-3) - - def test_slice_sampling_inputs_decode(): """Test _slice_sampling_inputs with decoding (num_tokens_per_batch > 1).""" from lmdeploy.pytorch.engine.logits_process import SamplingInputs