Skip to content

Commit 9ae1737

Browse files
authored
Merge branch 'main' into cye/decgrad-argfix
2 parents f30d840 + c146305 commit 9ae1737

File tree

8 files changed

+960
-190
lines changed

8 files changed

+960
-190
lines changed

megatron/core/inference/contexts/dynamic_context.py

Lines changed: 40 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1839,26 +1839,62 @@ def current_input_and_position_ids(
18391839
self.token_to_pos_ids[:num_tokens].unsqueeze(0),
18401840
)
18411841

1842-
def last_token_logits(self, logits: Tensor) -> Tensor:
1843-
"""Last tokens of logits.
1842+
def speculative_required_logit_indices(self, device: torch.device) -> Tensor:
1843+
"""Token-level indices needed for speculative decode verification.
1844+
1845+
Returns all decode token positions (base + speculative) concatenated
1846+
with the last token position of each prefill request.
18441847
18451848
Args:
1846-
logits (Tensor): Output logits of forward pass.
1849+
device (torch.device): Device on which to create the index tensor.
18471850
18481851
Return:
1849-
(Tensor) Last token logits.
1852+
(Tensor) 1-D indices into the packed token sequence, length
1853+
``num_decode_requests * (num_speculative_tokens + 1) + num_prefill_requests``.
18501854
"""
18511855
paused = self.paused_request_count
18521856
total = self.total_request_count
18531857
query_lengths = self.request_query_lengths[paused:total]
1858+
num_decode = self.num_decode_requests
1859+
1860+
decode_token_count = num_decode * (self.num_speculative_tokens + 1)
1861+
decode_indices = torch.arange(decode_token_count, device=device)
1862+
1863+
cumsum = torch.cumsum(query_lengths, dim=0)
1864+
prefill_last_indices = cumsum[num_decode:] - 1
1865+
1866+
return torch.cat([decode_indices, prefill_last_indices])
1867+
1868+
def last_token_logits(self, logits: Tensor) -> Tensor:
1869+
"""Select the logit positions needed for token generation.
1870+
1871+
When speculative decoding is active, decode requests need logits for all
1872+
their tokens (base + speculative) for verification, while prefill requests
1873+
only need the last token logit. This avoids materializing the full
1874+
vocab-sized logits for every prefill token, which causes large memory
1875+
spikes during prefill-heavy batches.
1876+
1877+
Args:
1878+
logits (Tensor): Output logits of forward pass, shape [1, S, H].
18541879
1880+
Return:
1881+
(Tensor) Selected logits, shape [N, H].
1882+
"""
18551883
# todo: @lmcafee, remove these asserts?
18561884
assert logits.size(0) == 1, f"logits.size(0) ({tuple(logits.shape)}) != 1"
18571885
assert logits.size(1) == self.padded_active_token_count, (
18581886
f"logits.size(1) ({tuple(logits.shape)}) != "
18591887
f"padded_active_token_count ({self.padded_active_token_count})."
18601888
)
18611889
logits_2d = logits.squeeze(0)
1890+
1891+
if self.num_speculative_tokens > 0:
1892+
selected = self.speculative_required_logit_indices(logits.device)
1893+
return logits_2d[selected, :]
1894+
1895+
paused = self.paused_request_count
1896+
total = self.total_request_count
1897+
query_lengths = self.request_query_lengths[paused:total]
18621898
last_token_idxs = torch.cumsum(query_lengths, dim=0) - 1
18631899
return logits_2d[last_token_idxs, :]
18641900

megatron/core/inference/engines/dynamic_engine.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -214,12 +214,9 @@ def __init__(self, controller: TextGenerationController, context: DynamicInferen
214214

215215
if self.num_speculative_tokens > 0:
216216
assert (
217-
self.num_speculative_tokens <= self.controller.num_mtp_heads
217+
model_config.mtp_use_repeated_layer
218+
or self.num_speculative_tokens <= self.controller.num_mtp_heads
218219
), f"Number of speculative tokens {self.num_speculative_tokens} must be less than or equal to number of MTP heads {self.controller.num_mtp_heads}"
219-
assert (
220-
not self.materialize_only_last_token_logits
221-
), "materialize_only_last_token_logits must be False when num_speculative_tokens > 0"
222-
223220
self.track_paused_request_events = inference_config.track_paused_request_events
224221
self.track_generated_token_events = inference_config.track_generated_token_events
225222
self.enable_chunked_prefill = inference_config.enable_chunked_prefill
@@ -1211,7 +1208,13 @@ def post_process_requests(
12111208
top_n_logprobs[req_idx] = top_n_logprobs[req_idx][:-num_stop_word_trim]
12121209

12131210
# Process log_probs if available (unified for both regular and chunked prefill)
1214-
if request_log_probs is not None:
1211+
# Skip for requests being finished due to stop words — tokens are not
1212+
# appended for these requests, so log probs must also be skipped to keep
1213+
# the two lists in sync.
1214+
if (
1215+
request_log_probs is not None
1216+
and request_id not in self.stop_word_being_finished_ids
1217+
):
12151218
# Initialize lists if they don't exist
12161219
if not request.prompt_log_probs:
12171220
request.prompt_log_probs = []
@@ -1244,7 +1247,12 @@ def post_process_requests(
12441247
request.generated_log_probs.extend(request_log_probs[split_idx:])
12451248

12461249
# Process top_n_logprobs if available (unified for both regular and chunked prefill)
1247-
if top_n_logprobs is not None and req_idx in top_n_logprobs:
1250+
# Same stop-word guard as log probs above.
1251+
if (
1252+
top_n_logprobs is not None
1253+
and req_idx in top_n_logprobs
1254+
and request_id not in self.stop_word_being_finished_ids
1255+
):
12481256
# Initialize lists if they don't exist
12491257
if request.prompt_top_n_logprobs is None:
12501258
request.prompt_top_n_logprobs = []

0 commit comments

Comments
 (0)