From b07a530ff95d9fc2cc91f6552282db629d419875 Mon Sep 17 00:00:00 2001 From: Qubitium Date: Mon, 6 Apr 2026 22:18:33 +0000 Subject: [PATCH 01/10] fix fa2 cb max_bt graph reuse --- .../continuous_batching/continuous_api.py | 16 +++++ .../continuous_batching/input_outputs.py | 64 ++++++++++++++++--- .../generation/continuous_batching/utils.py | 15 +++-- tests/generation/test_continuous_batching.py | 60 ++++++++++++++++- 4 files changed, 139 insertions(+), 16 deletions(-) diff --git a/src/transformers/generation/continuous_batching/continuous_api.py b/src/transformers/generation/continuous_batching/continuous_api.py index 4e2ef2a24adb..5709fc9ccb05 100644 --- a/src/transformers/generation/continuous_batching/continuous_api.py +++ b/src/transformers/generation/continuous_batching/continuous_api.py @@ -391,6 +391,22 @@ def prepare_next_batch(self) -> bool: self.inputs_and_outputs.prepare_batch_tensors( requests_in_batch, self.logit_processor, use_decode_fast_path, num_q_tokens, max_kv_read ) + if self._pad_inputs: + if self.inputs_and_outputs.use_block_table: + self.inputs_and_outputs.set_graph_bounds(1, 1) + else: + padded_max_seqlen_q = pad_to_interval( + self.inputs_and_outputs.max_seqlen_q, self.q_padding_interval_size, self.inputs_and_outputs.num_q_tokens + ) + padded_max_seqlen_k = { + layer_type: pad_to_interval( + self.inputs_and_outputs.max_seqlen_k[layer_type], + self.kv_padding_interval_size, + self.inputs_and_outputs.max_kv_read + self.inputs_and_outputs.num_q_tokens, + ) + for layer_type in self.inputs_and_outputs.max_seqlen_k + } + self.inputs_and_outputs.set_graph_bounds(padded_max_seqlen_q, padded_max_seqlen_k) self.metrics.record_kv_cache_memory_metrics(self.cache) return True diff --git a/src/transformers/generation/continuous_batching/input_outputs.py b/src/transformers/generation/continuous_batching/input_outputs.py index b5c4e5419bfc..281ea00e5b93 100644 --- a/src/transformers/generation/continuous_batching/input_outputs.py +++ b/src/transformers/generation/continuous_batching/input_outputs.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from collections.abc import Hashable from contextlib import nullcontext from dataclasses import dataclass from functools import partial @@ -128,6 +129,8 @@ def __init__( self.true_read_sizes = [0 for _ in range(cache.num_groups)] self.true_write_sizes = [0 for _ in range(cache.num_groups)] self.use_block_table = False # True if all requests in batch have query_length == 1 + self.graph_max_seqlen_q = 0 + self.graph_max_seqlen_k: dict[str, int] = {} # Setup other accumulators self.requests_in_batch: list[FutureRequestState] = [] self.req_id_to_new_token_position: dict[str, int] = {} # only used for async API @@ -194,6 +197,7 @@ def _setup_static_tensors(self, logit_processor: ContinuousBatchingLogitsProcess self.total_seqlen_q = 0 self.max_seqlen_q = 0 self.max_seqlen_k = dict.fromkeys(self.cumulative_seqlens_k.keys(), 0) + self.graph_max_seqlen_k = dict.fromkeys(self.cumulative_seqlens_k.keys(), 0) # If the attention mask is needed, it is allocated separately if attn_mask_is_needed(self.config): @@ -276,10 +280,12 @@ def _reset_static_tensors(self, full_reset: bool = False) -> None: # Reset the logits indices and output ids self.logits_indices[:q_len].zero_() self.output_ids[:, :q_len].zero_() + self.graph_max_seqlen_q = 0 # Reset the attributes that are either tensors or dict of tensors for layer_type in self.cumulative_seqlens_k: self.max_seqlen_k[layer_type] = 0 + self.graph_max_seqlen_k[layer_type] = 0 if self.attention_mask is not None: self.attention_mask[layer_type][:, :, :q_len, : q_len + kv_len].fill_( torch.finfo(self.model_dtype).min @@ -460,13 +466,14 @@ def get_model_kwargs(self, use_padding: bool = False) -> dict[str, Any]: q_size = self.num_q_tokens kv_size = self.max_kv_read + self.num_q_tokens batch_size = self.num_q_tokens if use_padding else self.true_batch_size + effective_max_seqlen_q = self.graph_max_seqlen_q if use_padding and self.graph_max_seqlen_q > 0 else self.max_seqlen_q # Prepare the kwargs, the attributes that are either tensors or dict of tensors are initialized to empty dicts. kwargs = PagedAttentionArgs( input_ids=self.input_ids[:q_size].unsqueeze(0), position_ids=self.position_ids[:q_size].unsqueeze(0), cu_seq_lens_q=self.cumulative_seqlens_q[: batch_size + 1], - max_seqlen_q=self.max_seqlen_q, + max_seqlen_q=effective_max_seqlen_q, logits_indices=self.logits_indices[:q_size], logits_processor_args=self._bulk_input_tensor[self.static_inputs :, :q_size], cu_seq_lens_k={}, @@ -483,8 +490,8 @@ def get_model_kwargs(self, use_padding: bool = False) -> dict[str, Any]: # some models like Qwen3-4B-Instruct-2507, if we don't include these tokens in cumulative_seqlens_q, there are # some NaNs in the output logits even for non-padded tokens. if use_padding: - self.max_seqlen_q = max(self.max_seqlen_q, q_size - self.total_seqlen_q) - kwargs.max_seqlen_q = self.max_seqlen_q + effective_max_seqlen_q = max(effective_max_seqlen_q, q_size - self.total_seqlen_q) + kwargs.max_seqlen_q = effective_max_seqlen_q self.cumulative_seqlens_q[self.true_batch_size + 1 :] = q_size # FIXME: is there another way to avoid this? It has a very slight impact on performance (~5 tok/s) @@ -509,14 +516,24 @@ def get_model_kwargs(self, use_padding: bool = False) -> dict[str, Any]: kwargs.attention_mask: dict[str, torch.Tensor] = {} for layer_type, seqlens_k in self.cumulative_seqlens_k.items(): kwargs.cu_seq_lens_k[layer_type] = seqlens_k[: batch_size + 1] - kwargs.max_seqlen_k[layer_type] = 1 if self.use_block_table else self.max_seqlen_k[layer_type] + graph_max_seqlen_k = ( + self.graph_max_seqlen_k[layer_type] + if use_padding and self.graph_max_seqlen_k[layer_type] > 0 + else self.max_seqlen_k[layer_type] + ) + kwargs.max_seqlen_k[layer_type] = 1 if self.use_block_table else graph_max_seqlen_k if self.attention_mask is not None: k_len = kv_size if use_padding else seqlens_k[batch_size] kwargs.attention_mask[layer_type] = self.attention_mask[layer_type][..., :q_size, :k_len] else: layer_type = layer_types[0] kwargs.cu_seq_lens_k = self.cumulative_seqlens_k[layer_type][: batch_size + 1] - kwargs.max_seqlen_k = 1 if self.use_block_table else self.max_seqlen_k[layer_type] + graph_max_seqlen_k = ( + self.graph_max_seqlen_k[layer_type] + if use_padding and self.graph_max_seqlen_k[layer_type] > 0 + else self.max_seqlen_k[layer_type] + ) + kwargs.max_seqlen_k = 1 if self.use_block_table else graph_max_seqlen_k if self.attention_mask is not None: k_len = kv_size if use_padding else self.cumulative_seqlens_k[layer_type][batch_size] kwargs.attention_mask = self.attention_mask[layer_type][..., :q_size, :k_len] @@ -531,16 +548,47 @@ def get_cb_kwargs(self) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: return 3 tensors to have the same interface as when using async batching.""" return self.carry_over_ids, self.output_ids, self.output_ids + def set_graph_bounds(self, max_seqlen_q: int, max_seqlen_k: int | dict[str, int]) -> None: + self.graph_max_seqlen_q = max_seqlen_q + if isinstance(max_seqlen_k, dict): + self.graph_max_seqlen_k.update(max_seqlen_k) + else: + for layer_type in self.graph_max_seqlen_k: + self.graph_max_seqlen_k[layer_type] = max_seqlen_k + + def get_graph_signature(self) -> Hashable: + """Return the CUDA graph signature for the current batch. + + Continuous batching pads tensor inputs to static shapes for CUDA graph replay, but FA-style kernels also depend + on non-tensor integer kwargs such as `max_seqlen_q` and `max_seqlen_k`. Reusing a graph across batches that + share padded `(Q, KV)` sizes but differ on those integers can replay kernels with stale launch parameters. + """ + max_seqlen_k_signature = tuple( + ( + layer_type, + self.graph_max_seqlen_k[layer_type] if self.graph_max_seqlen_k[layer_type] > 0 else self.max_seqlen_k[layer_type], + ) + for layer_type in self.cumulative_seqlens_k.keys() + ) + return ( + "decode_fast_path" if self.use_block_table else "varlen", + self.num_q_tokens, + self.max_kv_read, + self.graph_max_seqlen_q if self.graph_max_seqlen_q > 0 else self.max_seqlen_q, + max_seqlen_k_signature, + ) + def get_graph(self) -> torch.cuda.CUDAGraph | None: - graph = self.graphs.get_graph(self.num_q_tokens, self.max_kv_read) + graph_signature = self.get_graph_signature() + graph = self.graphs.get_graph(graph_signature) # If this point is reached, it means the next step will be a new graph capture if graph is None: self.graphs.plan_for_new_graph() - logger.info(f"Creating graph for {(self.num_q_tokens, self.max_kv_read) = }") + logger.info(f"Creating graph for {graph_signature = }") return graph def set_graph(self, graph: torch.cuda.CUDAGraph) -> None: - self.graphs.set_graph(self.num_q_tokens, self.max_kv_read, graph) + self.graphs.set_graph(self.get_graph_signature(), graph) class HostDeviceIOPair: diff --git a/src/transformers/generation/continuous_batching/utils.py b/src/transformers/generation/continuous_batching/utils.py index a827d247110e..bd028f37c6e9 100644 --- a/src/transformers/generation/continuous_batching/utils.py +++ b/src/transformers/generation/continuous_batching/utils.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from collections import OrderedDict +from collections.abc import Hashable from math import ceil from typing import Any @@ -29,7 +30,7 @@ def __init__(self, max_size: int) -> None: if max_size <= 0: raise ValueError(f"max_size must be positive, but got {max_size}") self.max_size = max_size - self._storage: OrderedDict[tuple[int, int], torch.cuda.CUDAGraph] = OrderedDict() + self._storage: OrderedDict[Hashable, torch.cuda.CUDAGraph] = OrderedDict() def __del__(self) -> None: original_max_size = self.max_size @@ -37,10 +38,10 @@ def __del__(self) -> None: self.plan_for_new_graph(silent=True) self.max_size = original_max_size - def get_graph(self, q_len: int, kv_len: int) -> torch.cuda.CUDAGraph | None: - graph = self._storage.get((q_len, kv_len)) + def get_graph(self, key: Hashable) -> torch.cuda.CUDAGraph | None: + graph = self._storage.get(key) if graph is not None: - self._storage.move_to_end((q_len, kv_len)) + self._storage.move_to_end(key) return graph def plan_for_new_graph(self, silent: bool = False) -> None: @@ -50,11 +51,11 @@ def plan_for_new_graph(self, silent: bool = False) -> None: logger.info(f"Evicting graph for {evicted_key = }") evicted_graph.reset() - def set_graph(self, q_len: int, kv_len: int, graph: torch.cuda.CUDAGraph) -> None: + def set_graph(self, key: Hashable, graph: torch.cuda.CUDAGraph) -> None: # In our use case, this should not have any effect because we plan for a new graph before it is captured self.plan_for_new_graph() - logger.info(f"Setting graph for {q_len = }, {kv_len = }") - self._storage[(q_len, kv_len)] = graph + logger.info(f"Setting graph for {key = }") + self._storage[key] = graph def attn_mask_is_needed(config: PretrainedConfig) -> bool: diff --git a/tests/generation/test_continuous_batching.py b/tests/generation/test_continuous_batching.py index 98566466a060..a43e787dc181 100644 --- a/tests/generation/test_continuous_batching.py +++ b/tests/generation/test_continuous_batching.py @@ -40,7 +40,7 @@ ) from transformers.generation.continuous_batching.cache_manager import FullAttentionCacheAllocator from transformers.generation.continuous_batching.continuous_api import ContinuousBatchProcessor, OutputRouter -from transformers.generation.continuous_batching.input_outputs import build_attention_mask +from transformers.generation.continuous_batching.input_outputs import ContinuousBatchingIOs, build_attention_mask from transformers.generation.continuous_batching.requests import GenerationOutput, RequestStatus from transformers.testing_utils import ( require_deterministic_for_xpu, @@ -202,6 +202,64 @@ def regular_generate( # Class for all continuous batching tests that do not require any accelerator. Usualy those test are faster to run. class ContinuousBatchingNoAcceleratorTest(unittest.TestCase): + def test_cuda_graph_signature_tracks_non_tensor_runtime_args(self) -> None: + """CUDA graph reuse must distinguish batches that share padded tensor sizes but not FA runtime ints.""" + + class _DummyLogitsProcessor: + tensors_required = 0 + + def fill_defaults(self, arg_storage: torch.Tensor) -> None: + return None + + def prepare_tensor_args(self, requests_in_batch: list[Any], arg_storage: torch.Tensor) -> None: + return None + + config = AutoConfig.from_pretrained("HuggingFaceTB/SmolLM-1.7B", attn_implementation="flash_attention_2") + cache = PagedAttentionCache( + config=config, + continuous_batching_config=ContinuousBatchingConfig(block_size=16, num_blocks=8, max_batch_tokens=32), + device="cpu", + ) + io = ContinuousBatchingIOs( + cache=cache, + config=config, + device=torch.device("cpu"), + model_dtype=torch.float16, + max_graphs=2, + return_logprobs=False, + logit_processor=_DummyLogitsProcessor(), + ) + + io.use_block_table = False + io.num_q_tokens = 16 + io.max_kv_read = 32 + io.max_seqlen_q = 7 + io.max_seqlen_k = dict.fromkeys(io.max_seqlen_k, 19) + base_signature = io.get_graph_signature() + + io.max_seqlen_q = 11 + self.assertNotEqual( + base_signature, + io.get_graph_signature(), + "Graph signature should change when max_seqlen_q changes under the same padded Q/KV sizes.", + ) + + io.max_seqlen_q = 7 + io.max_seqlen_k = dict.fromkeys(io.max_seqlen_k, 23) + self.assertNotEqual( + base_signature, + io.get_graph_signature(), + "Graph signature should change when max_seqlen_k changes under the same padded Q/KV sizes.", + ) + + io.max_seqlen_k = dict.fromkeys(io.max_seqlen_k, 19) + io.use_block_table = True + self.assertNotEqual( + base_signature, + io.get_graph_signature(), + "Graph signature should change when switching between varlen and decode-fast-path batches.", + ) + @parameterized.expand( [ (None, None, "0"), From a7c90440e89aabbaf3a0dbbf9976163d769bc701 Mon Sep 17 00:00:00 2001 From: Qubitium Date: Mon, 6 Apr 2026 22:44:11 +0000 Subject: [PATCH 02/10] fix async graph bounds fallback --- .../continuous_batching/input_outputs.py | 23 ++++ tests/generation/test_continuous_batching.py | 103 +++++++++++++++++- 2 files changed, 125 insertions(+), 1 deletion(-) diff --git a/src/transformers/generation/continuous_batching/input_outputs.py b/src/transformers/generation/continuous_batching/input_outputs.py index 281ea00e5b93..d26ab355fb29 100644 --- a/src/transformers/generation/continuous_batching/input_outputs.py +++ b/src/transformers/generation/continuous_batching/input_outputs.py @@ -244,6 +244,8 @@ def _transfer_inputs( other.total_seqlen_q = self.total_seqlen_q other.max_seqlen_q = self.max_seqlen_q other.max_seqlen_k = dict(self.max_seqlen_k.items()) + other.graph_max_seqlen_q = self.graph_max_seqlen_q + other.graph_max_seqlen_k = dict(self.graph_max_seqlen_k.items()) # Transfer static tensors maybe_stream = torch.cuda.stream(stream) if stream is not None else nullcontext() with maybe_stream: @@ -711,6 +713,22 @@ def __init__( def get_cumulative_seqlens(self) -> tuple[torch.Tensor, dict[str, torch.Tensor]]: return self.io_pairs[self.current_pair].host_io.get_cumulative_seqlens() + @property + def num_q_tokens(self) -> int: + return self.io_pairs[self.current_pair].host_io.num_q_tokens + + @property + def max_kv_read(self) -> int: + return self.io_pairs[self.current_pair].host_io.max_kv_read + + @property + def max_seqlen_q(self) -> int: + return self.io_pairs[self.current_pair].host_io.max_seqlen_q + + @property + def max_seqlen_k(self) -> dict[str, int]: + return self.io_pairs[self.current_pair].host_io.max_seqlen_k + # The prepare_batch_tensor method also has to prepare the carry over ids def prepare_batch_tensors( self, @@ -769,6 +787,11 @@ def get_cb_kwargs(self) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: current_pair.device_io.output_ids, ) + def set_graph_bounds(self, max_seqlen_q: int, max_seqlen_k: int | dict[str, int]) -> None: + io_pair = self.io_pairs[self.current_pair] + io_pair.host_io.set_graph_bounds(max_seqlen_q, max_seqlen_k) + io_pair.device_io.set_graph_bounds(max_seqlen_q, max_seqlen_k) + def carry_over_tokens( self, input_ids: torch.Tensor, carry_over_ids: torch.Tensor, prev_output_ids: torch.Tensor ) -> None: diff --git a/tests/generation/test_continuous_batching.py b/tests/generation/test_continuous_batching.py index a43e787dc181..580db7710e42 100644 --- a/tests/generation/test_continuous_batching.py +++ b/tests/generation/test_continuous_batching.py @@ -40,7 +40,11 @@ ) from transformers.generation.continuous_batching.cache_manager import FullAttentionCacheAllocator from transformers.generation.continuous_batching.continuous_api import ContinuousBatchProcessor, OutputRouter -from transformers.generation.continuous_batching.input_outputs import ContinuousBatchingIOs, build_attention_mask +from transformers.generation.continuous_batching.input_outputs import ( + ContinuousBatchingAsyncIOs, + ContinuousBatchingIOs, + build_attention_mask, +) from transformers.generation.continuous_batching.requests import GenerationOutput, RequestStatus from transformers.testing_utils import ( require_deterministic_for_xpu, @@ -260,6 +264,103 @@ def prepare_tensor_args(self, requests_in_batch: list[Any], arg_storage: torch.T "Graph signature should change when switching between varlen and decode-fast-path batches.", ) + def test_transfer_inputs_copies_graph_bounds(self) -> None: + """Host-to-device IO transfers must keep graph-bound metadata in sync with tensor inputs.""" + + class _DummyLogitsProcessor: + tensors_required = 0 + + def fill_defaults(self, arg_storage: torch.Tensor) -> None: + return None + + def prepare_tensor_args(self, requests_in_batch: list[Any], arg_storage: torch.Tensor) -> None: + return None + + config = AutoConfig.from_pretrained("HuggingFaceTB/SmolLM-1.7B", attn_implementation="flash_attention_2") + cache = PagedAttentionCache( + config=config, + continuous_batching_config=ContinuousBatchingConfig(block_size=16, num_blocks=8, max_batch_tokens=32), + device="cpu", + ) + source = ContinuousBatchingIOs( + cache=cache, + config=config, + device=torch.device("cpu"), + model_dtype=torch.float16, + max_graphs=2, + return_logprobs=False, + logit_processor=_DummyLogitsProcessor(), + ) + target = ContinuousBatchingIOs( + cache=cache, + config=config, + device=torch.device("cpu"), + model_dtype=torch.float16, + max_graphs=2, + return_logprobs=False, + logit_processor=_DummyLogitsProcessor(), + ) + + source.num_q_tokens = 16 + source.max_kv_read = 32 + source.max_seqlen_q = 7 + source.max_seqlen_k = dict.fromkeys(source.max_seqlen_k, 19) + source.set_graph_bounds(11, 23) + + source._transfer_inputs(target, stream=None) + + self.assertEqual(target.graph_max_seqlen_q, 11) + self.assertEqual(target.graph_max_seqlen_k, dict.fromkeys(target.graph_max_seqlen_k, 23)) + + +@require_torch_gpu +class ContinuousBatchingAsyncIOTest(unittest.TestCase): + def test_async_io_exposes_graph_bounds_interface(self) -> None: + """Async IOs should expose the same graph-bound interface as sync IOs.""" + + class _DummyLogitsProcessor: + tensors_required = 0 + + def fill_defaults(self, arg_storage: torch.Tensor) -> None: + return None + + def prepare_tensor_args(self, requests_in_batch: list[Any], arg_storage: torch.Tensor) -> None: + return None + + config = AutoConfig.from_pretrained("HuggingFaceTB/SmolLM-1.7B", attn_implementation="flash_attention_2") + cache = PagedAttentionCache( + config=config, + continuous_batching_config=ContinuousBatchingConfig(block_size=16, num_blocks=8, max_batch_tokens=32), + device=torch_device, + ) + io = ContinuousBatchingAsyncIOs( + cache=cache, + config=config, + device=torch.device(torch_device), + model_dtype=torch.float16, + max_graphs=2, + return_logprobs=False, + logit_processor=_DummyLogitsProcessor(), + ) + + self.assertEqual(io.num_q_tokens, 0) + self.assertEqual(io.max_kv_read, 0) + self.assertEqual(io.max_seqlen_q, 0) + self.assertEqual(io.max_seqlen_k, dict.fromkeys(io.max_seqlen_k, 0)) + + io.set_graph_bounds(13, 29) + + current_pair = io.io_pairs[io.current_pair] + self.assertEqual(current_pair.host_io.graph_max_seqlen_q, 13) + self.assertEqual(current_pair.device_io.graph_max_seqlen_q, 13) + self.assertEqual(current_pair.host_io.graph_max_seqlen_k, dict.fromkeys(current_pair.host_io.graph_max_seqlen_k, 29)) + self.assertEqual( + current_pair.device_io.graph_max_seqlen_k, + dict.fromkeys(current_pair.device_io.graph_max_seqlen_k, 29), + ) + + io.reset() + @parameterized.expand( [ (None, None, "0"), From a2a95fea250a7ff8550534d3ef78672f320d1be4 Mon Sep 17 00:00:00 2001 From: Qubitium Date: Mon, 6 Apr 2026 23:22:45 +0000 Subject: [PATCH 03/10] guard fa2 async graphs --- .../continuous_batching/continuous_api.py | 35 ++++++++++ tests/generation/test_continuous_batching.py | 66 ++++++++++++++++++- 2 files changed, 100 insertions(+), 1 deletion(-) diff --git a/src/transformers/generation/continuous_batching/continuous_api.py b/src/transformers/generation/continuous_batching/continuous_api.py index 5709fc9ccb05..66efd4985790 100644 --- a/src/transformers/generation/continuous_batching/continuous_api.py +++ b/src/transformers/generation/continuous_batching/continuous_api.py @@ -64,6 +64,27 @@ All defaults are stored in ContinuousBatchingConfig.resolve_sentinel_values(). """ +_ASYNC_CUDAGRAPH_FA2_MAX_BATCH_TOKENS = 768 + + +def _should_disable_async_batching_for_large_fa2_graphs( + *, + config: PretrainedConfig, + model_device: torch.device, + use_cuda_graph: bool, + use_async_batching: bool, + max_batch_tokens: int, +) -> bool: + """Return whether FA2 async+graph batching should be downgraded for large resolved batch-token budgets.""" + + return ( + model_device.type == "cuda" + and use_cuda_graph + and use_async_batching + and max_batch_tokens > _ASYNC_CUDAGRAPH_FA2_MAX_BATCH_TOKENS + and is_flash_attention_requested(config=config, version=2) + ) + # We cannot use `PreTrainedModel` for circular import reasons, so this helps keep track of the basic types class ProtoPretrainedModel(nn.Module): @@ -189,6 +210,20 @@ def __init__( self.max_batch_tokens = cache.max_batch_tokens self.metrics = ContinuousBatchProcessorMetrics(cache.max_batch_tokens) + if _should_disable_async_batching_for_large_fa2_graphs( + config=self.config, + model_device=self.model_device, + use_cuda_graph=self.use_cuda_graph, + use_async_batching=self.cb_config.use_async_batching, + max_batch_tokens=self.max_batch_tokens, + ): + logger.warning( + "Disabling async batching for flash_attention_2 continuous batching because " + f"{self.max_batch_tokens = } exceeds the validated async+graph limit of " + f"{_ASYNC_CUDAGRAPH_FA2_MAX_BATCH_TOKENS} tokens." + ) + self.cb_config.use_async_batching = False + # If the user turned on the decode fast path (ie. using a block table), check if it is available self._ensure_decode_fast_path_is_available() # this needs to happen before self.inputs_and_outputs is created diff --git a/tests/generation/test_continuous_batching.py b/tests/generation/test_continuous_batching.py index 580db7710e42..f52a815f7708 100644 --- a/tests/generation/test_continuous_batching.py +++ b/tests/generation/test_continuous_batching.py @@ -39,7 +39,11 @@ group_layers_by_attn_type, ) from transformers.generation.continuous_batching.cache_manager import FullAttentionCacheAllocator -from transformers.generation.continuous_batching.continuous_api import ContinuousBatchProcessor, OutputRouter +from transformers.generation.continuous_batching.continuous_api import ( + ContinuousBatchProcessor, + OutputRouter, + _should_disable_async_batching_for_large_fa2_graphs, +) from transformers.generation.continuous_batching.input_outputs import ( ContinuousBatchingAsyncIOs, ContinuousBatchingIOs, @@ -206,6 +210,66 @@ def regular_generate( # Class for all continuous batching tests that do not require any accelerator. Usualy those test are faster to run. class ContinuousBatchingNoAcceleratorTest(unittest.TestCase): + def test_async_graph_guard_only_trips_for_large_fa2_cuda_batches(self) -> None: + fa2_config = AutoConfig.from_pretrained("HuggingFaceTB/SmolLM-1.7B", attn_implementation="flash_attention_2") + sdpa_config = AutoConfig.from_pretrained("HuggingFaceTB/SmolLM-1.7B", attn_implementation="sdpa") + cuda_device = torch.device("cuda") + + self.assertTrue( + _should_disable_async_batching_for_large_fa2_graphs( + config=fa2_config, + model_device=cuda_device, + use_cuda_graph=True, + use_async_batching=True, + max_batch_tokens=1024, + ) + ) + self.assertFalse( + _should_disable_async_batching_for_large_fa2_graphs( + config=fa2_config, + model_device=cuda_device, + use_cuda_graph=True, + use_async_batching=True, + max_batch_tokens=768, + ) + ) + self.assertFalse( + _should_disable_async_batching_for_large_fa2_graphs( + config=fa2_config, + model_device=torch.device("cpu"), + use_cuda_graph=True, + use_async_batching=True, + max_batch_tokens=1024, + ) + ) + self.assertFalse( + _should_disable_async_batching_for_large_fa2_graphs( + config=fa2_config, + model_device=cuda_device, + use_cuda_graph=False, + use_async_batching=True, + max_batch_tokens=1024, + ) + ) + self.assertFalse( + _should_disable_async_batching_for_large_fa2_graphs( + config=fa2_config, + model_device=cuda_device, + use_cuda_graph=True, + use_async_batching=False, + max_batch_tokens=1024, + ) + ) + self.assertFalse( + _should_disable_async_batching_for_large_fa2_graphs( + config=sdpa_config, + model_device=cuda_device, + use_cuda_graph=True, + use_async_batching=True, + max_batch_tokens=1024, + ) + ) + def test_cuda_graph_signature_tracks_non_tensor_runtime_args(self) -> None: """CUDA graph reuse must distinguish batches that share padded tensor sizes but not FA runtime ints.""" From c21271a7c48bc5cc0aec19f0c066b06b50d5c4ed Mon Sep 17 00:00:00 2001 From: Qubitium Date: Mon, 6 Apr 2026 23:29:22 +0000 Subject: [PATCH 04/10] Revert "guard fa2 async graphs" This reverts commit a2a95fea250a7ff8550534d3ef78672f320d1be4. --- .../continuous_batching/continuous_api.py | 35 ---------- tests/generation/test_continuous_batching.py | 66 +------------------ 2 files changed, 1 insertion(+), 100 deletions(-) diff --git a/src/transformers/generation/continuous_batching/continuous_api.py b/src/transformers/generation/continuous_batching/continuous_api.py index 66efd4985790..5709fc9ccb05 100644 --- a/src/transformers/generation/continuous_batching/continuous_api.py +++ b/src/transformers/generation/continuous_batching/continuous_api.py @@ -64,27 +64,6 @@ All defaults are stored in ContinuousBatchingConfig.resolve_sentinel_values(). """ -_ASYNC_CUDAGRAPH_FA2_MAX_BATCH_TOKENS = 768 - - -def _should_disable_async_batching_for_large_fa2_graphs( - *, - config: PretrainedConfig, - model_device: torch.device, - use_cuda_graph: bool, - use_async_batching: bool, - max_batch_tokens: int, -) -> bool: - """Return whether FA2 async+graph batching should be downgraded for large resolved batch-token budgets.""" - - return ( - model_device.type == "cuda" - and use_cuda_graph - and use_async_batching - and max_batch_tokens > _ASYNC_CUDAGRAPH_FA2_MAX_BATCH_TOKENS - and is_flash_attention_requested(config=config, version=2) - ) - # We cannot use `PreTrainedModel` for circular import reasons, so this helps keep track of the basic types class ProtoPretrainedModel(nn.Module): @@ -210,20 +189,6 @@ def __init__( self.max_batch_tokens = cache.max_batch_tokens self.metrics = ContinuousBatchProcessorMetrics(cache.max_batch_tokens) - if _should_disable_async_batching_for_large_fa2_graphs( - config=self.config, - model_device=self.model_device, - use_cuda_graph=self.use_cuda_graph, - use_async_batching=self.cb_config.use_async_batching, - max_batch_tokens=self.max_batch_tokens, - ): - logger.warning( - "Disabling async batching for flash_attention_2 continuous batching because " - f"{self.max_batch_tokens = } exceeds the validated async+graph limit of " - f"{_ASYNC_CUDAGRAPH_FA2_MAX_BATCH_TOKENS} tokens." - ) - self.cb_config.use_async_batching = False - # If the user turned on the decode fast path (ie. using a block table), check if it is available self._ensure_decode_fast_path_is_available() # this needs to happen before self.inputs_and_outputs is created diff --git a/tests/generation/test_continuous_batching.py b/tests/generation/test_continuous_batching.py index f52a815f7708..580db7710e42 100644 --- a/tests/generation/test_continuous_batching.py +++ b/tests/generation/test_continuous_batching.py @@ -39,11 +39,7 @@ group_layers_by_attn_type, ) from transformers.generation.continuous_batching.cache_manager import FullAttentionCacheAllocator -from transformers.generation.continuous_batching.continuous_api import ( - ContinuousBatchProcessor, - OutputRouter, - _should_disable_async_batching_for_large_fa2_graphs, -) +from transformers.generation.continuous_batching.continuous_api import ContinuousBatchProcessor, OutputRouter from transformers.generation.continuous_batching.input_outputs import ( ContinuousBatchingAsyncIOs, ContinuousBatchingIOs, @@ -210,66 +206,6 @@ def regular_generate( # Class for all continuous batching tests that do not require any accelerator. Usualy those test are faster to run. class ContinuousBatchingNoAcceleratorTest(unittest.TestCase): - def test_async_graph_guard_only_trips_for_large_fa2_cuda_batches(self) -> None: - fa2_config = AutoConfig.from_pretrained("HuggingFaceTB/SmolLM-1.7B", attn_implementation="flash_attention_2") - sdpa_config = AutoConfig.from_pretrained("HuggingFaceTB/SmolLM-1.7B", attn_implementation="sdpa") - cuda_device = torch.device("cuda") - - self.assertTrue( - _should_disable_async_batching_for_large_fa2_graphs( - config=fa2_config, - model_device=cuda_device, - use_cuda_graph=True, - use_async_batching=True, - max_batch_tokens=1024, - ) - ) - self.assertFalse( - _should_disable_async_batching_for_large_fa2_graphs( - config=fa2_config, - model_device=cuda_device, - use_cuda_graph=True, - use_async_batching=True, - max_batch_tokens=768, - ) - ) - self.assertFalse( - _should_disable_async_batching_for_large_fa2_graphs( - config=fa2_config, - model_device=torch.device("cpu"), - use_cuda_graph=True, - use_async_batching=True, - max_batch_tokens=1024, - ) - ) - self.assertFalse( - _should_disable_async_batching_for_large_fa2_graphs( - config=fa2_config, - model_device=cuda_device, - use_cuda_graph=False, - use_async_batching=True, - max_batch_tokens=1024, - ) - ) - self.assertFalse( - _should_disable_async_batching_for_large_fa2_graphs( - config=fa2_config, - model_device=cuda_device, - use_cuda_graph=True, - use_async_batching=False, - max_batch_tokens=1024, - ) - ) - self.assertFalse( - _should_disable_async_batching_for_large_fa2_graphs( - config=sdpa_config, - model_device=cuda_device, - use_cuda_graph=True, - use_async_batching=True, - max_batch_tokens=1024, - ) - ) - def test_cuda_graph_signature_tracks_non_tensor_runtime_args(self) -> None: """CUDA graph reuse must distinguish batches that share padded tensor sizes but not FA runtime ints.""" From c74a4f88b9df3e1af1177308dabbea11541e6d0d Mon Sep 17 00:00:00 2001 From: Qubitium Date: Mon, 6 Apr 2026 23:38:33 +0000 Subject: [PATCH 05/10] document graph bound replay --- .../generation/continuous_batching/continuous_api.py | 3 +++ .../generation/continuous_batching/input_outputs.py | 4 ++++ src/transformers/generation/continuous_batching/utils.py | 1 + 3 files changed, 8 insertions(+) diff --git a/src/transformers/generation/continuous_batching/continuous_api.py b/src/transformers/generation/continuous_batching/continuous_api.py index 5709fc9ccb05..3c2d16cef9f8 100644 --- a/src/transformers/generation/continuous_batching/continuous_api.py +++ b/src/transformers/generation/continuous_batching/continuous_api.py @@ -393,8 +393,11 @@ def prepare_next_batch(self) -> bool: ) if self._pad_inputs: if self.inputs_and_outputs.use_block_table: + # Block-table decode ignores the varlen max_seqlen kwargs, so keep the graph signature constant. self.inputs_and_outputs.set_graph_bounds(1, 1) else: + # FA varlen kernels also specialize on max_seqlen_* integer kwargs, so bucket those for graph replay + # separately from the already-padded tensor shapes. padded_max_seqlen_q = pad_to_interval( self.inputs_and_outputs.max_seqlen_q, self.q_padding_interval_size, self.inputs_and_outputs.num_q_tokens ) diff --git a/src/transformers/generation/continuous_batching/input_outputs.py b/src/transformers/generation/continuous_batching/input_outputs.py index d26ab355fb29..767176681745 100644 --- a/src/transformers/generation/continuous_batching/input_outputs.py +++ b/src/transformers/generation/continuous_batching/input_outputs.py @@ -244,6 +244,7 @@ def _transfer_inputs( other.total_seqlen_q = self.total_seqlen_q other.max_seqlen_q = self.max_seqlen_q other.max_seqlen_k = dict(self.max_seqlen_k.items()) + # Graph captures depend on these padded scalar bounds too, so async transfers must mirror them with tensors. other.graph_max_seqlen_q = self.graph_max_seqlen_q other.graph_max_seqlen_k = dict(self.graph_max_seqlen_k.items()) # Transfer static tensors @@ -551,6 +552,7 @@ def get_cb_kwargs(self) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: return self.carry_over_ids, self.output_ids, self.output_ids def set_graph_bounds(self, max_seqlen_q: int, max_seqlen_k: int | dict[str, int]) -> None: + """Store the padded max_seqlen kwargs that the next graph capture or replay must see.""" self.graph_max_seqlen_q = max_seqlen_q if isinstance(max_seqlen_k, dict): self.graph_max_seqlen_k.update(max_seqlen_k) @@ -565,6 +567,7 @@ def get_graph_signature(self) -> Hashable: on non-tensor integer kwargs such as `max_seqlen_q` and `max_seqlen_k`. Reusing a graph across batches that share padded `(Q, KV)` sizes but differ on those integers can replay kernels with stale launch parameters. """ + # Hybrid models can have per-layer KV limits, so the signature has to encode them explicitly. max_seqlen_k_signature = tuple( ( layer_type, @@ -788,6 +791,7 @@ def get_cb_kwargs(self) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: ) def set_graph_bounds(self, max_seqlen_q: int, max_seqlen_k: int | dict[str, int]) -> None: + # Keep the host and device pair on the same graph signature before async H2D transfer starts. io_pair = self.io_pairs[self.current_pair] io_pair.host_io.set_graph_bounds(max_seqlen_q, max_seqlen_k) io_pair.device_io.set_graph_bounds(max_seqlen_q, max_seqlen_k) diff --git a/src/transformers/generation/continuous_batching/utils.py b/src/transformers/generation/continuous_batching/utils.py index bd028f37c6e9..6c0cf5312a5e 100644 --- a/src/transformers/generation/continuous_batching/utils.py +++ b/src/transformers/generation/continuous_batching/utils.py @@ -39,6 +39,7 @@ def __del__(self) -> None: self.max_size = original_max_size def get_graph(self, key: Hashable) -> torch.cuda.CUDAGraph | None: + # Replay safety sometimes depends on scalar kwargs as well as tensor extents, so the key is generic. graph = self._storage.get(key) if graph is not None: self._storage.move_to_end(key) From 33e9cbb08e858cc2d01268aff01dd3a87f83ec2b Mon Sep 17 00:00:00 2001 From: Qubitium Date: Mon, 6 Apr 2026 23:50:39 +0000 Subject: [PATCH 06/10] simplify graph reuse fix --- .../continuous_batching/continuous_api.py | 35 ++--- .../continuous_batching/input_outputs.py | 144 +++++++++--------- .../generation/continuous_batching/utils.py | 10 +- tests/generation/test_continuous_batching.py | 118 +------------- 4 files changed, 98 insertions(+), 209 deletions(-) diff --git a/src/transformers/generation/continuous_batching/continuous_api.py b/src/transformers/generation/continuous_batching/continuous_api.py index 3c2d16cef9f8..14572d3b648a 100644 --- a/src/transformers/generation/continuous_batching/continuous_api.py +++ b/src/transformers/generation/continuous_batching/continuous_api.py @@ -391,25 +391,6 @@ def prepare_next_batch(self) -> bool: self.inputs_and_outputs.prepare_batch_tensors( requests_in_batch, self.logit_processor, use_decode_fast_path, num_q_tokens, max_kv_read ) - if self._pad_inputs: - if self.inputs_and_outputs.use_block_table: - # Block-table decode ignores the varlen max_seqlen kwargs, so keep the graph signature constant. - self.inputs_and_outputs.set_graph_bounds(1, 1) - else: - # FA varlen kernels also specialize on max_seqlen_* integer kwargs, so bucket those for graph replay - # separately from the already-padded tensor shapes. - padded_max_seqlen_q = pad_to_interval( - self.inputs_and_outputs.max_seqlen_q, self.q_padding_interval_size, self.inputs_and_outputs.num_q_tokens - ) - padded_max_seqlen_k = { - layer_type: pad_to_interval( - self.inputs_and_outputs.max_seqlen_k[layer_type], - self.kv_padding_interval_size, - self.inputs_and_outputs.max_kv_read + self.inputs_and_outputs.num_q_tokens, - ) - for layer_type in self.inputs_and_outputs.max_seqlen_k - } - self.inputs_and_outputs.set_graph_bounds(padded_max_seqlen_q, padded_max_seqlen_k) self.metrics.record_kv_cache_memory_metrics(self.cache) return True @@ -541,7 +522,9 @@ def _generation_step(self, model: nn.Module) -> None: # Otherwise, we use create or replay the graph (cuda is available in this path) else: - graph = self.inputs_and_outputs.get_graph() + graph = self.inputs_and_outputs.get_graph( + batch_data, self.q_padding_interval_size, self.kv_padding_interval_size + ) # Case: the graph already exists, so we replay it if graph is not None: with torch.cuda.stream(compute_stream): @@ -549,12 +532,12 @@ def _generation_step(self, model: nn.Module) -> None: # Otherwise, the graph does not exist, so we create it else: args = (model, batch_data, carry_over_ids, prev_output_ids, output_ids) - self.capture_graph(forward_fn, compute_stream, *args) + self.capture_graph(forward_fn, compute_stream, batch_data, *args) # In any case, we transfer the outputs to the host self.inputs_and_outputs.retrieve_device_outputs() - def capture_graph(self, forward_fn: Any, compute_stream: torch.cuda.Stream, *args) -> None: + def capture_graph(self, forward_fn: Any, compute_stream: torch.cuda.Stream, batch_data: dict, *args) -> None: # Warmup (ensures the right result is computed before capturing the graph) with torch.cuda.stream(compute_stream): forward_fn(*args) @@ -566,7 +549,9 @@ def capture_graph(self, forward_fn: Any, compute_stream: torch.cuda.Stream, *arg with torch.cuda.graph(graph, stream=compute_stream, pool=self.graph_pool, capture_error_mode="thread_local"): forward_fn(*args) # Store - self.inputs_and_outputs.set_graph(graph) + self.inputs_and_outputs.set_graph( + graph, batch_data, self.q_padding_interval_size, self.kv_padding_interval_size + ) @traced def _forward_process_and_sample( @@ -685,7 +670,7 @@ def warmup( forward_fn = self._compiled_varlen or self._forward_process_and_sample forward_fn_args = (model, batch_data, carry_over_ids, prev_output_ids, output_ids) if self.use_cuda_graph: - self.capture_graph(forward_fn, compute_stream, *forward_fn_args) + self.capture_graph(forward_fn, compute_stream, batch_data, *forward_fn_args) else: with torch.cuda.stream(compute_stream): forward_fn(*forward_fn_args) @@ -723,7 +708,7 @@ def warmup( forward_fn = self._compiled_decode or self._forward_process_and_sample forward_fn_args = (model, batch_data, carry_over_ids, prev_output_ids, output_ids) if self.use_cuda_graph: - self.capture_graph(forward_fn, compute_stream, *forward_fn_args) + self.capture_graph(forward_fn, compute_stream, batch_data, *forward_fn_args) else: with torch.cuda.stream(compute_stream): forward_fn(*forward_fn_args) diff --git a/src/transformers/generation/continuous_batching/input_outputs.py b/src/transformers/generation/continuous_batching/input_outputs.py index 767176681745..471299f68e54 100644 --- a/src/transformers/generation/continuous_batching/input_outputs.py +++ b/src/transformers/generation/continuous_batching/input_outputs.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from collections.abc import Hashable from contextlib import nullcontext from dataclasses import dataclass from functools import partial @@ -27,7 +26,7 @@ from .cache import PagedAttentionCache from .cb_logits_processors import ContinuousBatchingLogitsProcessorList from .requests import TMP_TOKEN_ID, FutureRequestState, logger -from .utils import CudaGraphBuffer, aligned_divide, attn_mask_is_needed, build_attention_mask +from .utils import CudaGraphBuffer, aligned_divide, attn_mask_is_needed, build_attention_mask, pad_to_interval @dataclass @@ -129,8 +128,6 @@ def __init__( self.true_read_sizes = [0 for _ in range(cache.num_groups)] self.true_write_sizes = [0 for _ in range(cache.num_groups)] self.use_block_table = False # True if all requests in batch have query_length == 1 - self.graph_max_seqlen_q = 0 - self.graph_max_seqlen_k: dict[str, int] = {} # Setup other accumulators self.requests_in_batch: list[FutureRequestState] = [] self.req_id_to_new_token_position: dict[str, int] = {} # only used for async API @@ -197,7 +194,6 @@ def _setup_static_tensors(self, logit_processor: ContinuousBatchingLogitsProcess self.total_seqlen_q = 0 self.max_seqlen_q = 0 self.max_seqlen_k = dict.fromkeys(self.cumulative_seqlens_k.keys(), 0) - self.graph_max_seqlen_k = dict.fromkeys(self.cumulative_seqlens_k.keys(), 0) # If the attention mask is needed, it is allocated separately if attn_mask_is_needed(self.config): @@ -244,9 +240,6 @@ def _transfer_inputs( other.total_seqlen_q = self.total_seqlen_q other.max_seqlen_q = self.max_seqlen_q other.max_seqlen_k = dict(self.max_seqlen_k.items()) - # Graph captures depend on these padded scalar bounds too, so async transfers must mirror them with tensors. - other.graph_max_seqlen_q = self.graph_max_seqlen_q - other.graph_max_seqlen_k = dict(self.graph_max_seqlen_k.items()) # Transfer static tensors maybe_stream = torch.cuda.stream(stream) if stream is not None else nullcontext() with maybe_stream: @@ -283,12 +276,10 @@ def _reset_static_tensors(self, full_reset: bool = False) -> None: # Reset the logits indices and output ids self.logits_indices[:q_len].zero_() self.output_ids[:, :q_len].zero_() - self.graph_max_seqlen_q = 0 # Reset the attributes that are either tensors or dict of tensors for layer_type in self.cumulative_seqlens_k: self.max_seqlen_k[layer_type] = 0 - self.graph_max_seqlen_k[layer_type] = 0 if self.attention_mask is not None: self.attention_mask[layer_type][:, :, :q_len, : q_len + kv_len].fill_( torch.finfo(self.model_dtype).min @@ -469,14 +460,13 @@ def get_model_kwargs(self, use_padding: bool = False) -> dict[str, Any]: q_size = self.num_q_tokens kv_size = self.max_kv_read + self.num_q_tokens batch_size = self.num_q_tokens if use_padding else self.true_batch_size - effective_max_seqlen_q = self.graph_max_seqlen_q if use_padding and self.graph_max_seqlen_q > 0 else self.max_seqlen_q # Prepare the kwargs, the attributes that are either tensors or dict of tensors are initialized to empty dicts. kwargs = PagedAttentionArgs( input_ids=self.input_ids[:q_size].unsqueeze(0), position_ids=self.position_ids[:q_size].unsqueeze(0), cu_seq_lens_q=self.cumulative_seqlens_q[: batch_size + 1], - max_seqlen_q=effective_max_seqlen_q, + max_seqlen_q=self.max_seqlen_q, logits_indices=self.logits_indices[:q_size], logits_processor_args=self._bulk_input_tensor[self.static_inputs :, :q_size], cu_seq_lens_k={}, @@ -493,8 +483,7 @@ def get_model_kwargs(self, use_padding: bool = False) -> dict[str, Any]: # some models like Qwen3-4B-Instruct-2507, if we don't include these tokens in cumulative_seqlens_q, there are # some NaNs in the output logits even for non-padded tokens. if use_padding: - effective_max_seqlen_q = max(effective_max_seqlen_q, q_size - self.total_seqlen_q) - kwargs.max_seqlen_q = effective_max_seqlen_q + kwargs.max_seqlen_q = max(kwargs.max_seqlen_q, q_size - self.total_seqlen_q) self.cumulative_seqlens_q[self.true_batch_size + 1 :] = q_size # FIXME: is there another way to avoid this? It has a very slight impact on performance (~5 tok/s) @@ -519,24 +508,14 @@ def get_model_kwargs(self, use_padding: bool = False) -> dict[str, Any]: kwargs.attention_mask: dict[str, torch.Tensor] = {} for layer_type, seqlens_k in self.cumulative_seqlens_k.items(): kwargs.cu_seq_lens_k[layer_type] = seqlens_k[: batch_size + 1] - graph_max_seqlen_k = ( - self.graph_max_seqlen_k[layer_type] - if use_padding and self.graph_max_seqlen_k[layer_type] > 0 - else self.max_seqlen_k[layer_type] - ) - kwargs.max_seqlen_k[layer_type] = 1 if self.use_block_table else graph_max_seqlen_k + kwargs.max_seqlen_k[layer_type] = 1 if self.use_block_table else self.max_seqlen_k[layer_type] if self.attention_mask is not None: k_len = kv_size if use_padding else seqlens_k[batch_size] kwargs.attention_mask[layer_type] = self.attention_mask[layer_type][..., :q_size, :k_len] else: layer_type = layer_types[0] kwargs.cu_seq_lens_k = self.cumulative_seqlens_k[layer_type][: batch_size + 1] - graph_max_seqlen_k = ( - self.graph_max_seqlen_k[layer_type] - if use_padding and self.graph_max_seqlen_k[layer_type] > 0 - else self.max_seqlen_k[layer_type] - ) - kwargs.max_seqlen_k = 1 if self.use_block_table else graph_max_seqlen_k + kwargs.max_seqlen_k = 1 if self.use_block_table else self.max_seqlen_k[layer_type] if self.attention_mask is not None: k_len = kv_size if use_padding else self.cumulative_seqlens_k[layer_type][batch_size] kwargs.attention_mask = self.attention_mask[layer_type][..., :q_size, :k_len] @@ -551,40 +530,60 @@ def get_cb_kwargs(self) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: return 3 tensors to have the same interface as when using async batching.""" return self.carry_over_ids, self.output_ids, self.output_ids - def set_graph_bounds(self, max_seqlen_q: int, max_seqlen_k: int | dict[str, int]) -> None: - """Store the padded max_seqlen kwargs that the next graph capture or replay must see.""" - self.graph_max_seqlen_q = max_seqlen_q - if isinstance(max_seqlen_k, dict): - self.graph_max_seqlen_k.update(max_seqlen_k) - else: - for layer_type in self.graph_max_seqlen_k: - self.graph_max_seqlen_k[layer_type] = max_seqlen_k - - def get_graph_signature(self) -> Hashable: + def get_graph_signature( + self, + batch_data: dict[str, Any], + q_padding_interval_size: int = 0, + kv_padding_interval_size: int = 0, + ) -> tuple[Any, ...]: """Return the CUDA graph signature for the current batch. Continuous batching pads tensor inputs to static shapes for CUDA graph replay, but FA-style kernels also depend on non-tensor integer kwargs such as `max_seqlen_q` and `max_seqlen_k`. Reusing a graph across batches that share padded `(Q, KV)` sizes but differ on those integers can replay kernels with stale launch parameters. """ - # Hybrid models can have per-layer KV limits, so the signature has to encode them explicitly. - max_seqlen_k_signature = tuple( - ( - layer_type, - self.graph_max_seqlen_k[layer_type] if self.graph_max_seqlen_k[layer_type] > 0 else self.max_seqlen_k[layer_type], + q_tokens = self.num_q_tokens + max_kv_read = self.max_kv_read + max_seqlen_q = batch_data["max_seqlen_q"] + max_seqlen_k = batch_data["max_seqlen_k"] + if q_padding_interval_size > 0 and not self.use_block_table: + max_seqlen_q = pad_to_interval(max_seqlen_q, q_padding_interval_size, q_tokens) + max_kv_total = max_kv_read + q_tokens + if self.use_block_table: + max_seqlen_k_signature = tuple((layer_type, 1) for layer_type in max_seqlen_k) if isinstance(max_seqlen_k, dict) else 1 + else: + max_seqlen_k_signature = ( + tuple( + ( + layer_type, + pad_to_interval(max_seqlen_k[layer_type], kv_padding_interval_size, max_kv_total) + if kv_padding_interval_size > 0 + else max_seqlen_k[layer_type], + ) + for layer_type in max_seqlen_k + ) + if isinstance(max_seqlen_k, dict) + else ( + pad_to_interval(max_seqlen_k, kv_padding_interval_size, max_kv_total) + if kv_padding_interval_size > 0 + else max_seqlen_k + ) ) - for layer_type in self.cumulative_seqlens_k.keys() - ) return ( "decode_fast_path" if self.use_block_table else "varlen", - self.num_q_tokens, - self.max_kv_read, - self.graph_max_seqlen_q if self.graph_max_seqlen_q > 0 else self.max_seqlen_q, + q_tokens, + max_kv_read, + 1 if self.use_block_table else max_seqlen_q, max_seqlen_k_signature, ) - def get_graph(self) -> torch.cuda.CUDAGraph | None: - graph_signature = self.get_graph_signature() + def get_graph( + self, + batch_data: dict[str, Any], + q_padding_interval_size: int = 0, + kv_padding_interval_size: int = 0, + ) -> torch.cuda.CUDAGraph | None: + graph_signature = self.get_graph_signature(batch_data, q_padding_interval_size, kv_padding_interval_size) graph = self.graphs.get_graph(graph_signature) # If this point is reached, it means the next step will be a new graph capture if graph is None: @@ -592,8 +591,16 @@ def get_graph(self) -> torch.cuda.CUDAGraph | None: logger.info(f"Creating graph for {graph_signature = }") return graph - def set_graph(self, graph: torch.cuda.CUDAGraph) -> None: - self.graphs.set_graph(self.get_graph_signature(), graph) + def set_graph( + self, + graph: torch.cuda.CUDAGraph, + batch_data: dict[str, Any], + q_padding_interval_size: int = 0, + kv_padding_interval_size: int = 0, + ) -> None: + self.graphs.set_graph( + self.get_graph_signature(batch_data, q_padding_interval_size, kv_padding_interval_size), graph + ) class HostDeviceIOPair: @@ -724,14 +731,6 @@ def num_q_tokens(self) -> int: def max_kv_read(self) -> int: return self.io_pairs[self.current_pair].host_io.max_kv_read - @property - def max_seqlen_q(self) -> int: - return self.io_pairs[self.current_pair].host_io.max_seqlen_q - - @property - def max_seqlen_k(self) -> dict[str, int]: - return self.io_pairs[self.current_pair].host_io.max_seqlen_k - # The prepare_batch_tensor method also has to prepare the carry over ids def prepare_batch_tensors( self, @@ -790,12 +789,6 @@ def get_cb_kwargs(self) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: current_pair.device_io.output_ids, ) - def set_graph_bounds(self, max_seqlen_q: int, max_seqlen_k: int | dict[str, int]) -> None: - # Keep the host and device pair on the same graph signature before async H2D transfer starts. - io_pair = self.io_pairs[self.current_pair] - io_pair.host_io.set_graph_bounds(max_seqlen_q, max_seqlen_k) - io_pair.device_io.set_graph_bounds(max_seqlen_q, max_seqlen_k) - def carry_over_tokens( self, input_ids: torch.Tensor, carry_over_ids: torch.Tensor, prev_output_ids: torch.Tensor ) -> None: @@ -817,11 +810,26 @@ def output_ids(self) -> torch.Tensor: # The output ids are used to copy_ the infered tokens: they need to be on the device return self.io_pairs[self.current_pair].device_io.output_ids - def get_graph(self) -> torch.cuda.CUDAGraph | None: - return self.io_pairs[self.current_pair].device_io.get_graph() + def get_graph( + self, + batch_data: dict[str, Any], + q_padding_interval_size: int = 0, + kv_padding_interval_size: int = 0, + ) -> torch.cuda.CUDAGraph | None: + return self.io_pairs[self.current_pair].device_io.get_graph( + batch_data, q_padding_interval_size, kv_padding_interval_size + ) - def set_graph(self, graph: torch.cuda.CUDAGraph) -> None: - self.io_pairs[self.current_pair].device_io.set_graph(graph) + def set_graph( + self, + graph: torch.cuda.CUDAGraph, + batch_data: dict[str, Any], + q_padding_interval_size: int = 0, + kv_padding_interval_size: int = 0, + ) -> None: + self.io_pairs[self.current_pair].device_io.set_graph( + graph, batch_data, q_padding_interval_size, kv_padding_interval_size + ) @property def use_block_table(self) -> bool: diff --git a/src/transformers/generation/continuous_batching/utils.py b/src/transformers/generation/continuous_batching/utils.py index 6c0cf5312a5e..6ccc881573b1 100644 --- a/src/transformers/generation/continuous_batching/utils.py +++ b/src/transformers/generation/continuous_batching/utils.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. from collections import OrderedDict -from collections.abc import Hashable from math import ceil from typing import Any @@ -23,6 +22,9 @@ from .requests import FutureRequestState, RequestState, RequestStatus, logger +GraphSignature = tuple[Any, ...] + + class CudaGraphBuffer: """A fixed-size dict for CUDA graphs with LRU eviction when full.""" @@ -30,7 +32,7 @@ def __init__(self, max_size: int) -> None: if max_size <= 0: raise ValueError(f"max_size must be positive, but got {max_size}") self.max_size = max_size - self._storage: OrderedDict[Hashable, torch.cuda.CUDAGraph] = OrderedDict() + self._storage: OrderedDict[GraphSignature, torch.cuda.CUDAGraph] = OrderedDict() def __del__(self) -> None: original_max_size = self.max_size @@ -38,7 +40,7 @@ def __del__(self) -> None: self.plan_for_new_graph(silent=True) self.max_size = original_max_size - def get_graph(self, key: Hashable) -> torch.cuda.CUDAGraph | None: + def get_graph(self, key: GraphSignature) -> torch.cuda.CUDAGraph | None: # Replay safety sometimes depends on scalar kwargs as well as tensor extents, so the key is generic. graph = self._storage.get(key) if graph is not None: @@ -52,7 +54,7 @@ def plan_for_new_graph(self, silent: bool = False) -> None: logger.info(f"Evicting graph for {evicted_key = }") evicted_graph.reset() - def set_graph(self, key: Hashable, graph: torch.cuda.CUDAGraph) -> None: + def set_graph(self, key: GraphSignature, graph: torch.cuda.CUDAGraph) -> None: # In our use case, this should not have any effect because we plan for a new graph before it is captured self.plan_for_new_graph() logger.info(f"Setting graph for {key = }") diff --git a/tests/generation/test_continuous_batching.py b/tests/generation/test_continuous_batching.py index 580db7710e42..6b5a43ae0c42 100644 --- a/tests/generation/test_continuous_batching.py +++ b/tests/generation/test_continuous_batching.py @@ -40,11 +40,7 @@ ) from transformers.generation.continuous_batching.cache_manager import FullAttentionCacheAllocator from transformers.generation.continuous_batching.continuous_api import ContinuousBatchProcessor, OutputRouter -from transformers.generation.continuous_batching.input_outputs import ( - ContinuousBatchingAsyncIOs, - ContinuousBatchingIOs, - build_attention_mask, -) +from transformers.generation.continuous_batching.input_outputs import ContinuousBatchingIOs, build_attention_mask from transformers.generation.continuous_batching.requests import GenerationOutput, RequestStatus from transformers.testing_utils import ( require_deterministic_for_xpu, @@ -237,130 +233,28 @@ def prepare_tensor_args(self, requests_in_batch: list[Any], arg_storage: torch.T io.use_block_table = False io.num_q_tokens = 16 io.max_kv_read = 32 - io.max_seqlen_q = 7 - io.max_seqlen_k = dict.fromkeys(io.max_seqlen_k, 19) - base_signature = io.get_graph_signature() + base_batch_data = {"max_seqlen_q": 7, "max_seqlen_k": dict.fromkeys(io.max_seqlen_k, 19)} + base_signature = io.get_graph_signature(base_batch_data) - io.max_seqlen_q = 11 self.assertNotEqual( base_signature, - io.get_graph_signature(), + io.get_graph_signature({"max_seqlen_q": 11, "max_seqlen_k": base_batch_data["max_seqlen_k"]}), "Graph signature should change when max_seqlen_q changes under the same padded Q/KV sizes.", ) - io.max_seqlen_q = 7 - io.max_seqlen_k = dict.fromkeys(io.max_seqlen_k, 23) self.assertNotEqual( base_signature, - io.get_graph_signature(), + io.get_graph_signature({"max_seqlen_q": 7, "max_seqlen_k": dict.fromkeys(io.max_seqlen_k, 23)}), "Graph signature should change when max_seqlen_k changes under the same padded Q/KV sizes.", ) - io.max_seqlen_k = dict.fromkeys(io.max_seqlen_k, 19) io.use_block_table = True self.assertNotEqual( base_signature, - io.get_graph_signature(), + io.get_graph_signature(base_batch_data), "Graph signature should change when switching between varlen and decode-fast-path batches.", ) - def test_transfer_inputs_copies_graph_bounds(self) -> None: - """Host-to-device IO transfers must keep graph-bound metadata in sync with tensor inputs.""" - - class _DummyLogitsProcessor: - tensors_required = 0 - - def fill_defaults(self, arg_storage: torch.Tensor) -> None: - return None - - def prepare_tensor_args(self, requests_in_batch: list[Any], arg_storage: torch.Tensor) -> None: - return None - - config = AutoConfig.from_pretrained("HuggingFaceTB/SmolLM-1.7B", attn_implementation="flash_attention_2") - cache = PagedAttentionCache( - config=config, - continuous_batching_config=ContinuousBatchingConfig(block_size=16, num_blocks=8, max_batch_tokens=32), - device="cpu", - ) - source = ContinuousBatchingIOs( - cache=cache, - config=config, - device=torch.device("cpu"), - model_dtype=torch.float16, - max_graphs=2, - return_logprobs=False, - logit_processor=_DummyLogitsProcessor(), - ) - target = ContinuousBatchingIOs( - cache=cache, - config=config, - device=torch.device("cpu"), - model_dtype=torch.float16, - max_graphs=2, - return_logprobs=False, - logit_processor=_DummyLogitsProcessor(), - ) - - source.num_q_tokens = 16 - source.max_kv_read = 32 - source.max_seqlen_q = 7 - source.max_seqlen_k = dict.fromkeys(source.max_seqlen_k, 19) - source.set_graph_bounds(11, 23) - - source._transfer_inputs(target, stream=None) - - self.assertEqual(target.graph_max_seqlen_q, 11) - self.assertEqual(target.graph_max_seqlen_k, dict.fromkeys(target.graph_max_seqlen_k, 23)) - - -@require_torch_gpu -class ContinuousBatchingAsyncIOTest(unittest.TestCase): - def test_async_io_exposes_graph_bounds_interface(self) -> None: - """Async IOs should expose the same graph-bound interface as sync IOs.""" - - class _DummyLogitsProcessor: - tensors_required = 0 - - def fill_defaults(self, arg_storage: torch.Tensor) -> None: - return None - - def prepare_tensor_args(self, requests_in_batch: list[Any], arg_storage: torch.Tensor) -> None: - return None - - config = AutoConfig.from_pretrained("HuggingFaceTB/SmolLM-1.7B", attn_implementation="flash_attention_2") - cache = PagedAttentionCache( - config=config, - continuous_batching_config=ContinuousBatchingConfig(block_size=16, num_blocks=8, max_batch_tokens=32), - device=torch_device, - ) - io = ContinuousBatchingAsyncIOs( - cache=cache, - config=config, - device=torch.device(torch_device), - model_dtype=torch.float16, - max_graphs=2, - return_logprobs=False, - logit_processor=_DummyLogitsProcessor(), - ) - - self.assertEqual(io.num_q_tokens, 0) - self.assertEqual(io.max_kv_read, 0) - self.assertEqual(io.max_seqlen_q, 0) - self.assertEqual(io.max_seqlen_k, dict.fromkeys(io.max_seqlen_k, 0)) - - io.set_graph_bounds(13, 29) - - current_pair = io.io_pairs[io.current_pair] - self.assertEqual(current_pair.host_io.graph_max_seqlen_q, 13) - self.assertEqual(current_pair.device_io.graph_max_seqlen_q, 13) - self.assertEqual(current_pair.host_io.graph_max_seqlen_k, dict.fromkeys(current_pair.host_io.graph_max_seqlen_k, 29)) - self.assertEqual( - current_pair.device_io.graph_max_seqlen_k, - dict.fromkeys(current_pair.device_io.graph_max_seqlen_k, 29), - ) - - io.reset() - @parameterized.expand( [ (None, None, "0"), From 9a9d4eda48dc838c7d4aa231bccf1651f1bb28f2 Mon Sep 17 00:00:00 2001 From: Qubitium Date: Tue, 7 Apr 2026 00:43:00 +0000 Subject: [PATCH 07/10] enable fa2 decode fast path --- .../generation/configuration_utils.py | 11 +++++---- .../generation/continuous_batching/cache.py | 18 ++++++++------ .../continuous_batching/continuous_api.py | 8 ++++--- tests/generation/test_continuous_batching.py | 24 +++++++++++++++++++ 4 files changed, 46 insertions(+), 15 deletions(-) diff --git a/src/transformers/generation/configuration_utils.py b/src/transformers/generation/configuration_utils.py index f8702dde2942..ccb8862615f3 100644 --- a/src/transformers/generation/configuration_utils.py +++ b/src/transformers/generation/configuration_utils.py @@ -1558,9 +1558,10 @@ class ContinuousBatchingConfig: Maximum number of tokens in a batch. Auto-inferred from GPU memory when `None`. max_memory_percent (`float`, *optional*, defaults to 0.8): Maximum percentage of free GPU memory (after the model is loaded) to use for the KV cache. - max_blocks_per_request (`int`, *optional*, defaults to 0): + max_blocks_per_request (`int`, *optional*): Maximum blocks per request, used in the `flash_attn_with_kvcache` fast decode path to dimension - the block table. Setting this to 0 disables the fast decode path. + the block table. When left unset, continuous batching can choose a FlashAttention-specific default. + Setting this to 0 disables the fast decode path explicitly. allow_block_sharing (`bool`, *optional*, defaults to `True`): Whether to allow block sharing for prefix caching. Block sharing can only be allowed, never forced, as some models do not support it. Disable if you have few short prompts but long generation lengths. @@ -1608,9 +1609,9 @@ class ContinuousBatchingConfig: # The max percentage of free GPU memory (after the model is loaded) to use for the KV cache. max_memory_percent: float = 0.8 - # This is only used in the flash_attn_with_kvcache fast decode path to dimension the block table. If it is set to 0, - # the fast decode path will not be used. Currently turned off by default. - max_blocks_per_request: int | None = 0 + # This is only used in the flash_attn_with_kvcache fast decode path to dimension the block table. None keeps the + # door open for backend-specific defaults, while 0 still disables the path explicitly. + max_blocks_per_request: int | None = None # Block sharing can only be allowed, but never forced: some model just do not support it. If you only have a few # short prompts, but long generation lengths, you might want to disable block sharing. diff --git a/src/transformers/generation/continuous_batching/cache.py b/src/transformers/generation/continuous_batching/cache.py index 9fd0d3afba11..00f65a51fd5a 100644 --- a/src/transformers/generation/continuous_batching/cache.py +++ b/src/transformers/generation/continuous_batching/cache.py @@ -58,6 +58,15 @@ def group_layers_by_attn_type(config: PreTrainedConfig) -> tuple[list[list[int]] return layer_groups, group_types +def default_flash_attention_max_blocks_per_request(config: PreTrainedConfig, max_batch_tokens: int) -> int: + """Pick the decode-fast-path block-table size for FlashAttention continuous batching.""" + if not ( + is_flash_attention_requested(config, version=2) or is_flash_attention_requested(config, version=3) + ): + return 0 + return 16 if max_batch_tokens > 4096 else 1 + + @attach_tracer() class PagedAttentionCache: """ @@ -207,15 +216,10 @@ def __init__( f"{self.max_batch_tokens = } {num_attention_masks = }" ) - # If max_blocks_per_request is not set, the default value is 16 max blocks. With default block size of 256, this - # means a max sequence length of 4096 tokens for the fast decode path. + # Resolve the block-table size after max_batch_tokens is known so the auto path can scale with the final cache. max_blocks_per_request = continuous_batching_config.max_blocks_per_request if max_blocks_per_request is None: - max_blocks_per_request = 0 - # logger.info( TODO: uncomment when we have good defaults - # f"max_blocks_per_request was not set, using {max_blocks_per_request}. This means max sequence " - # f"length for the decode fast path is {max_blocks_per_request * self.block_size}." - # ) + max_blocks_per_request = default_flash_attention_max_blocks_per_request(config, self.max_batch_tokens) self.max_blocks_per_request = max_blocks_per_request # Initialize the cache diff --git a/src/transformers/generation/continuous_batching/continuous_api.py b/src/transformers/generation/continuous_batching/continuous_api.py index 14572d3b648a..b481ecfd61de 100644 --- a/src/transformers/generation/continuous_batching/continuous_api.py +++ b/src/transformers/generation/continuous_batching/continuous_api.py @@ -257,8 +257,10 @@ def __del__(self) -> None: def _ensure_decode_fast_path_is_available(self) -> None: """Ensures the decode fast path is available. If it is not, set the max blocks per request to 0.""" if self.cache.max_blocks_per_request > 0: - # NOTE: block table should be available with FA2 and FA3, but there seems to be an issue with FA2 atm - if is_flash_attention_requested(self.config, version=3): + # flash_attn_with_kvcache underpins the paged decode fast path for both FA2 and FA3. + if is_flash_attention_requested(self.config, version=2) or is_flash_attention_requested( + self.config, version=3 + ): flash_attn_with_kvcache = lazy_import_paged_flash_attention(self.config._attn_implementation)[1] conditions = [ self.cache.num_sliding_attention_groups == 0, # TODO: add support for sliding window layers @@ -274,7 +276,7 @@ def _ensure_decode_fast_path_is_available(self) -> None: else: logger.warning( f"Although {self.cache.max_blocks_per_request = }, the decode fast path is not available " - f"because the attention implementation is not FA3. Got {self.config._attn_implementation = }." + f"because the attention implementation is not FA2/FA3. Got {self.config._attn_implementation = }." ) self.cache.max_blocks_per_request = 0 diff --git a/tests/generation/test_continuous_batching.py b/tests/generation/test_continuous_batching.py index 6b5a43ae0c42..08e56cd17886 100644 --- a/tests/generation/test_continuous_batching.py +++ b/tests/generation/test_continuous_batching.py @@ -16,6 +16,7 @@ import gc import itertools import unittest +from types import SimpleNamespace from typing import Any from unittest.mock import patch @@ -36,6 +37,7 @@ PagedAttentionCache, PagedAttentionMemoryHandler, SlidingAttentionCacheAllocator, + default_flash_attention_max_blocks_per_request, group_layers_by_attn_type, ) from transformers.generation.continuous_batching.cache_manager import FullAttentionCacheAllocator @@ -202,6 +204,28 @@ def regular_generate( # Class for all continuous batching tests that do not require any accelerator. Usualy those test are faster to run. class ContinuousBatchingNoAcceleratorTest(unittest.TestCase): + def test_flash_attention_auto_max_blocks_per_request(self) -> None: + """FlashAttention auto-defaults should scale the block table with the resolved batch-token budget.""" + + config = SimpleNamespace(_attn_implementation="flash_attention_2") + self.assertEqual(default_flash_attention_max_blocks_per_request(config, 2048), 1) + self.assertEqual(default_flash_attention_max_blocks_per_request(config, 8192), 16) + self.assertEqual(default_flash_attention_max_blocks_per_request(SimpleNamespace(_attn_implementation="sdpa"), 8192), 0) + + def test_decode_fast_path_accepts_fa2(self) -> None: + """FA2 should keep the block-table decode path enabled when flash_attn_with_kvcache is available.""" + + cache = SimpleNamespace(max_blocks_per_request=4, num_sliding_attention_groups=0) + processor = SimpleNamespace(cache=cache, config=SimpleNamespace(_attn_implementation="flash_attention_2")) + + with patch("torch.cuda.is_available", return_value=True), patch( + "transformers.generation.continuous_batching.continuous_api.lazy_import_paged_flash_attention", + return_value=(None, object()), + ): + ContinuousBatchProcessor._ensure_decode_fast_path_is_available(processor) + + self.assertEqual(cache.max_blocks_per_request, 4) + def test_cuda_graph_signature_tracks_non_tensor_runtime_args(self) -> None: """CUDA graph reuse must distinguish batches that share padded tensor sizes but not FA runtime ints.""" From f28d9299341ec42aaa2f9269202b074a3f365353 Mon Sep 17 00:00:00 2001 From: Qubitium Date: Tue, 7 Apr 2026 01:48:28 +0000 Subject: [PATCH 08/10] restore cb max_blocks default --- src/transformers/generation/configuration_utils.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/transformers/generation/configuration_utils.py b/src/transformers/generation/configuration_utils.py index ccb8862615f3..bf17963012a3 100644 --- a/src/transformers/generation/configuration_utils.py +++ b/src/transformers/generation/configuration_utils.py @@ -1558,10 +1558,9 @@ class ContinuousBatchingConfig: Maximum number of tokens in a batch. Auto-inferred from GPU memory when `None`. max_memory_percent (`float`, *optional*, defaults to 0.8): Maximum percentage of free GPU memory (after the model is loaded) to use for the KV cache. - max_blocks_per_request (`int`, *optional*): + max_blocks_per_request (`int`, *optional*, defaults to 0): Maximum blocks per request, used in the `flash_attn_with_kvcache` fast decode path to dimension - the block table. When left unset, continuous batching can choose a FlashAttention-specific default. - Setting this to 0 disables the fast decode path explicitly. + the block table. Setting this to 0 disables the fast decode path. allow_block_sharing (`bool`, *optional*, defaults to `True`): Whether to allow block sharing for prefix caching. Block sharing can only be allowed, never forced, as some models do not support it. Disable if you have few short prompts but long generation lengths. @@ -1609,9 +1608,10 @@ class ContinuousBatchingConfig: # The max percentage of free GPU memory (after the model is loaded) to use for the KV cache. max_memory_percent: float = 0.8 - # This is only used in the flash_attn_with_kvcache fast decode path to dimension the block table. None keeps the - # door open for backend-specific defaults, while 0 still disables the path explicitly. - max_blocks_per_request: int | None = None + # This is only used in the flash_attn_with_kvcache fast decode path to dimension the block table. If it is set to 0, + # the fast decode path will not be used. Callers can still pass None explicitly if they want backend-specific + # defaults to decide the block table size later. + max_blocks_per_request: int | None = 0 # Block sharing can only be allowed, but never forced: some model just do not support it. If you only have a few # short prompts, but long generation lengths, you might want to disable block sharing. From 45b1271a2629257b3ee3fc0d511c9bd696e5b879 Mon Sep 17 00:00:00 2001 From: Qubitium Date: Tue, 7 Apr 2026 01:49:50 +0000 Subject: [PATCH 09/10] restore cb config comment --- src/transformers/generation/configuration_utils.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/transformers/generation/configuration_utils.py b/src/transformers/generation/configuration_utils.py index bf17963012a3..f8702dde2942 100644 --- a/src/transformers/generation/configuration_utils.py +++ b/src/transformers/generation/configuration_utils.py @@ -1609,8 +1609,7 @@ class ContinuousBatchingConfig: max_memory_percent: float = 0.8 # This is only used in the flash_attn_with_kvcache fast decode path to dimension the block table. If it is set to 0, - # the fast decode path will not be used. Callers can still pass None explicitly if they want backend-specific - # defaults to decide the block table size later. + # the fast decode path will not be used. Currently turned off by default. max_blocks_per_request: int | None = 0 # Block sharing can only be allowed, but never forced: some model just do not support it. If you only have a few From f921f65912968d16802cac1b0f4f3c6ef6b9a732 Mon Sep 17 00:00:00 2001 From: Qubitium Date: Tue, 7 Apr 2026 02:16:18 +0000 Subject: [PATCH 10/10] style --- .../generation/continuous_batching/cache.py | 4 +--- .../generation/continuous_batching/input_outputs.py | 4 +++- tests/generation/test_continuous_batching.py | 13 +++++++++---- 3 files changed, 13 insertions(+), 8 deletions(-) diff --git a/src/transformers/generation/continuous_batching/cache.py b/src/transformers/generation/continuous_batching/cache.py index 00f65a51fd5a..0432b4853c48 100644 --- a/src/transformers/generation/continuous_batching/cache.py +++ b/src/transformers/generation/continuous_batching/cache.py @@ -60,9 +60,7 @@ def group_layers_by_attn_type(config: PreTrainedConfig) -> tuple[list[list[int]] def default_flash_attention_max_blocks_per_request(config: PreTrainedConfig, max_batch_tokens: int) -> int: """Pick the decode-fast-path block-table size for FlashAttention continuous batching.""" - if not ( - is_flash_attention_requested(config, version=2) or is_flash_attention_requested(config, version=3) - ): + if not (is_flash_attention_requested(config, version=2) or is_flash_attention_requested(config, version=3)): return 0 return 16 if max_batch_tokens > 4096 else 1 diff --git a/src/transformers/generation/continuous_batching/input_outputs.py b/src/transformers/generation/continuous_batching/input_outputs.py index 471299f68e54..4d376dbf2c78 100644 --- a/src/transformers/generation/continuous_batching/input_outputs.py +++ b/src/transformers/generation/continuous_batching/input_outputs.py @@ -550,7 +550,9 @@ def get_graph_signature( max_seqlen_q = pad_to_interval(max_seqlen_q, q_padding_interval_size, q_tokens) max_kv_total = max_kv_read + q_tokens if self.use_block_table: - max_seqlen_k_signature = tuple((layer_type, 1) for layer_type in max_seqlen_k) if isinstance(max_seqlen_k, dict) else 1 + max_seqlen_k_signature = ( + tuple((layer_type, 1) for layer_type in max_seqlen_k) if isinstance(max_seqlen_k, dict) else 1 + ) else: max_seqlen_k_signature = ( tuple( diff --git a/tests/generation/test_continuous_batching.py b/tests/generation/test_continuous_batching.py index 08e56cd17886..abd0d54dd99f 100644 --- a/tests/generation/test_continuous_batching.py +++ b/tests/generation/test_continuous_batching.py @@ -210,7 +210,9 @@ def test_flash_attention_auto_max_blocks_per_request(self) -> None: config = SimpleNamespace(_attn_implementation="flash_attention_2") self.assertEqual(default_flash_attention_max_blocks_per_request(config, 2048), 1) self.assertEqual(default_flash_attention_max_blocks_per_request(config, 8192), 16) - self.assertEqual(default_flash_attention_max_blocks_per_request(SimpleNamespace(_attn_implementation="sdpa"), 8192), 0) + self.assertEqual( + default_flash_attention_max_blocks_per_request(SimpleNamespace(_attn_implementation="sdpa"), 8192), 0 + ) def test_decode_fast_path_accepts_fa2(self) -> None: """FA2 should keep the block-table decode path enabled when flash_attn_with_kvcache is available.""" @@ -218,9 +220,12 @@ def test_decode_fast_path_accepts_fa2(self) -> None: cache = SimpleNamespace(max_blocks_per_request=4, num_sliding_attention_groups=0) processor = SimpleNamespace(cache=cache, config=SimpleNamespace(_attn_implementation="flash_attention_2")) - with patch("torch.cuda.is_available", return_value=True), patch( - "transformers.generation.continuous_batching.continuous_api.lazy_import_paged_flash_attention", - return_value=(None, object()), + with ( + patch("torch.cuda.is_available", return_value=True), + patch( + "transformers.generation.continuous_batching.continuous_api.lazy_import_paged_flash_attention", + return_value=(None, object()), + ), ): ContinuousBatchProcessor._ensure_decode_fast_path_is_available(processor)