diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py
index ffb7266a5b2f..42a070e07162 100644
--- a/src/transformers/generation/utils.py
+++ b/src/transformers/generation/utils.py
@@ -2973,9 +2973,16 @@ def _get_top_k_continuations(
 
         # Gather the top K scores from _all_ beams.
         if do_sample:
-            topk_indices = torch.multinomial(
-                nn.functional.softmax(accumulated_log_probs, dim=-1), num_samples=beams_to_keep
-            )
+            probs = nn.functional.softmax(accumulated_log_probs, dim=-1)
+            # torch.multinomial on CUDA requires the last dimension to be <= 2**24.
+            # When num_beams * vocab_size exceeds this, pre-filter to the top candidates.
+            _MULTINOMIAL_MAX = 2**24
+            if probs.shape[-1] > _MULTINOMIAL_MAX:
+                top_values, top_indices = torch.topk(probs, k=_MULTINOMIAL_MAX, dim=-1)
+                sampled = torch.multinomial(top_values, num_samples=beams_to_keep)
+                topk_indices = torch.gather(top_indices, dim=1, index=sampled)
+            else:
+                topk_indices = torch.multinomial(probs, num_samples=beams_to_keep)
             topk_log_probs = torch.gather(input=accumulated_log_probs, dim=1, index=topk_indices)
         else:
             topk_log_probs, topk_indices = torch.topk(accumulated_log_probs, k=beams_to_keep)