@@ -1839,26 +1839,62 @@ def current_input_and_position_ids(
18391839 self .token_to_pos_ids [:num_tokens ].unsqueeze (0 ),
18401840 )
18411841
def speculative_required_logit_indices(self, device: torch.device) -> Tensor:
    """Token-level indices needed for speculative decode verification.

    Returns all decode token positions (base + speculative) concatenated
    with the last token position of each prefill request.

    Args:
        device (torch.device): Device on which to create the index tensor.

    Return:
        (Tensor) 1-D indices into the packed token sequence, length
        ``num_decode_requests * (num_speculative_tokens + 1) + num_prefill_requests``.
    """
    paused = self.paused_request_count
    total = self.total_request_count
    # Active (non-paused) requests; assumes decode requests are packed
    # before prefill requests in this slice — TODO confirm against caller.
    query_lengths = self.request_query_lengths[paused:total]
    num_decode = self.num_decode_requests

    # Every decode token (base + speculative) is needed for verification,
    # so take the entire leading decode span.
    decode_token_count = num_decode * (self.num_speculative_tokens + 1)
    decode_indices = torch.arange(decode_token_count, device=device)

    # Prefill requests follow the decode span; only each request's final
    # token position is needed.
    cumsum = torch.cumsum(query_lengths, dim=0)
    # Move to `device` explicitly: decode_indices is created on `device`,
    # and torch.cat requires both operands on the same device even when
    # request_query_lengths happens to live elsewhere.
    prefill_last_indices = (cumsum[num_decode:] - 1).to(device)

    return torch.cat([decode_indices, prefill_last_indices])
def last_token_logits(self, logits: Tensor) -> Tensor:
    """Select the logit positions needed for token generation.

    With speculative decoding active, decode requests require logits for
    all of their tokens (base + speculative) so verification can run,
    while prefill requests only require the final token's logits. This
    selection avoids materializing vocab-sized logits for every prefill
    token, which would spike memory on prefill-heavy batches.

    Args:
        logits (Tensor): Output logits of forward pass, shape [1, S, H].

    Return:
        (Tensor) Selected logits, shape [N, H].
    """
    # todo: @lmcafee, remove these asserts?
    assert logits.size(0) == 1, f"logits.size(0) ({tuple(logits.shape)}) != 1"
    assert logits.size(1) == self.padded_active_token_count, (
        f"logits.size(1) ({tuple(logits.shape)}) != "
        f"padded_active_token_count ({self.padded_active_token_count})."
    )
    flat_logits = logits.squeeze(0)

    # Speculative path: keep every decode token plus prefill last tokens.
    if self.num_speculative_tokens > 0:
        keep = self.speculative_required_logit_indices(logits.device)
        return flat_logits[keep, :]

    # Standard path: one row per active request — its final token.
    active_lengths = self.request_query_lengths[
        self.paused_request_count : self.total_request_count
    ]
    final_positions = active_lengths.cumsum(dim=0) - 1
    return flat_logits[final_positions, :]
18641900
0 commit comments