diff --git a/src/transformers/generation/continuous_batching/cache.py b/src/transformers/generation/continuous_batching/cache.py index 9fd0d3afba11..0432b4853c48 100644 --- a/src/transformers/generation/continuous_batching/cache.py +++ b/src/transformers/generation/continuous_batching/cache.py @@ -58,6 +58,13 @@ def group_layers_by_attn_type(config: PreTrainedConfig) -> tuple[list[list[int]] return layer_groups, group_types +def default_flash_attention_max_blocks_per_request(config: PreTrainedConfig, max_batch_tokens: int) -> int: + """Pick the decode-fast-path block-table size for FlashAttention continuous batching.""" + if not (is_flash_attention_requested(config, version=2) or is_flash_attention_requested(config, version=3)): + return 0 + return 16 if max_batch_tokens > 4096 else 1 + + @attach_tracer() class PagedAttentionCache: """ @@ -207,15 +214,10 @@ def __init__( f"{self.max_batch_tokens = } {num_attention_masks = }" ) - # If max_blocks_per_request is not set, the default value is 16 max blocks. With default block size of 256, this - # means a max sequence length of 4096 tokens for the fast decode path. + # Resolve the block-table size after max_batch_tokens is known so the auto path can scale with the final cache. max_blocks_per_request = continuous_batching_config.max_blocks_per_request if max_blocks_per_request is None: - max_blocks_per_request = 0 - # logger.info( TODO: uncomment when we have good defaults - # f"max_blocks_per_request was not set, using {max_blocks_per_request}. This means max sequence " - # f"length for the decode fast path is {max_blocks_per_request * self.block_size}." - # ) + max_blocks_per_request = default_flash_attention_max_blocks_per_request(config, self.max_batch_tokens) self.max_blocks_per_request = max_blocks_per_request # Initialize the cache diff --git a/src/transformers/generation/continuous_batching/continuous_api.py b/src/transformers/generation/continuous_batching/continuous_api.py index 4e2ef2a24adb..b481ecfd61de 100644 --- a/src/transformers/generation/continuous_batching/continuous_api.py +++ b/src/transformers/generation/continuous_batching/continuous_api.py @@ -257,8 +257,10 @@ def __del__(self) -> None: def _ensure_decode_fast_path_is_available(self) -> None: """Ensures the decode fast path is available. If it is not, set the max blocks per request to 0.""" if self.cache.max_blocks_per_request > 0: - # NOTE: block table should be available with FA2 and FA3, but there seems to be an issue with FA2 atm - if is_flash_attention_requested(self.config, version=3): + # flash_attn_with_kvcache underpins the paged decode fast path for both FA2 and FA3. + if is_flash_attention_requested(self.config, version=2) or is_flash_attention_requested( + self.config, version=3 + ): flash_attn_with_kvcache = lazy_import_paged_flash_attention(self.config._attn_implementation)[1] conditions = [ self.cache.num_sliding_attention_groups == 0, # TODO: add support for sliding window layers @@ -274,7 +276,7 @@ def _ensure_decode_fast_path_is_available(self) -> None: else: logger.warning( f"Although {self.cache.max_blocks_per_request = }, the decode fast path is not available " - f"because the attention implementation is not FA3. Got {self.config._attn_implementation = }." + f"because the attention implementation is not FA2/FA3. Got {self.config._attn_implementation = }." ) self.cache.max_blocks_per_request = 0 @@ -522,7 +524,9 @@ def _generation_step(self, model: nn.Module) -> None: # Otherwise, we use create or replay the graph (cuda is available in this path) else: - graph = self.inputs_and_outputs.get_graph() + graph = self.inputs_and_outputs.get_graph( + batch_data, self.q_padding_interval_size, self.kv_padding_interval_size + ) # Case: the graph already exists, so we replay it if graph is not None: with torch.cuda.stream(compute_stream): @@ -530,12 +534,12 @@ def _generation_step(self, model: nn.Module) -> None: # Otherwise, the graph does not exist, so we create it else: args = (model, batch_data, carry_over_ids, prev_output_ids, output_ids) - self.capture_graph(forward_fn, compute_stream, *args) + self.capture_graph(forward_fn, compute_stream, batch_data, *args) # In any case, we transfer the outputs to the host self.inputs_and_outputs.retrieve_device_outputs() - def capture_graph(self, forward_fn: Any, compute_stream: torch.cuda.Stream, *args) -> None: + def capture_graph(self, forward_fn: Any, compute_stream: torch.cuda.Stream, batch_data: dict, *args) -> None: # Warmup (ensures the right result is computed before capturing the graph) with torch.cuda.stream(compute_stream): forward_fn(*args) @@ -547,7 +551,9 @@ def capture_graph(self, forward_fn: Any, compute_stream: torch.cuda.Stream, *arg with torch.cuda.graph(graph, stream=compute_stream, pool=self.graph_pool, capture_error_mode="thread_local"): forward_fn(*args) # Store - self.inputs_and_outputs.set_graph(graph) + self.inputs_and_outputs.set_graph( + graph, batch_data, self.q_padding_interval_size, self.kv_padding_interval_size + ) @traced def _forward_process_and_sample( @@ -666,7 +672,7 @@ def warmup( forward_fn = self._compiled_varlen or self._forward_process_and_sample forward_fn_args = (model, batch_data, carry_over_ids, prev_output_ids, output_ids) if self.use_cuda_graph: - self.capture_graph(forward_fn, compute_stream, *forward_fn_args) + self.capture_graph(forward_fn, compute_stream, batch_data, *forward_fn_args) else: with torch.cuda.stream(compute_stream): forward_fn(*forward_fn_args) @@ -704,7 +710,7 @@ def warmup( forward_fn = self._compiled_decode or self._forward_process_and_sample forward_fn_args = (model, batch_data, carry_over_ids, prev_output_ids, output_ids) if self.use_cuda_graph: - self.capture_graph(forward_fn, compute_stream, *forward_fn_args) + self.capture_graph(forward_fn, compute_stream, batch_data, *forward_fn_args) else: with torch.cuda.stream(compute_stream): forward_fn(*forward_fn_args) diff --git a/src/transformers/generation/continuous_batching/input_outputs.py b/src/transformers/generation/continuous_batching/input_outputs.py index b5c4e5419bfc..4d376dbf2c78 100644 --- a/src/transformers/generation/continuous_batching/input_outputs.py +++ b/src/transformers/generation/continuous_batching/input_outputs.py @@ -26,7 +26,7 @@ from .cache import PagedAttentionCache from .cb_logits_processors import ContinuousBatchingLogitsProcessorList from .requests import TMP_TOKEN_ID, FutureRequestState, logger -from .utils import CudaGraphBuffer, aligned_divide, attn_mask_is_needed, build_attention_mask +from .utils import CudaGraphBuffer, aligned_divide, attn_mask_is_needed, build_attention_mask, pad_to_interval @dataclass @@ -483,8 +483,7 @@ def get_model_kwargs(self, use_padding: bool = False) -> dict[str, Any]: # some models like Qwen3-4B-Instruct-2507, if we don't include these tokens in cumulative_seqlens_q, there are # some NaNs in the output logits even for non-padded tokens. if use_padding: - self.max_seqlen_q = max(self.max_seqlen_q, q_size - self.total_seqlen_q) - kwargs.max_seqlen_q = self.max_seqlen_q + kwargs.max_seqlen_q = max(kwargs.max_seqlen_q, q_size - self.total_seqlen_q) self.cumulative_seqlens_q[self.true_batch_size + 1 :] = q_size # FIXME: is there another way to avoid this? It has a very slight impact on performance (~5 tok/s) @@ -531,16 +530,79 @@ def get_cb_kwargs(self) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: return 3 tensors to have the same interface as when using async batching.""" return self.carry_over_ids, self.output_ids, self.output_ids - def get_graph(self) -> torch.cuda.CUDAGraph | None: - graph = self.graphs.get_graph(self.num_q_tokens, self.max_kv_read) + def get_graph_signature( + self, + batch_data: dict[str, Any], + q_padding_interval_size: int = 0, + kv_padding_interval_size: int = 0, + ) -> tuple[Any, ...]: + """Return the CUDA graph signature for the current batch. + + Continuous batching pads tensor inputs to static shapes for CUDA graph replay, but FA-style kernels also depend + on non-tensor integer kwargs such as `max_seqlen_q` and `max_seqlen_k`. Reusing a graph across batches that + share padded `(Q, KV)` sizes but differ on those integers can replay kernels with stale launch parameters. + """ + q_tokens = self.num_q_tokens + max_kv_read = self.max_kv_read + max_seqlen_q = batch_data["max_seqlen_q"] + max_seqlen_k = batch_data["max_seqlen_k"] + if q_padding_interval_size > 0 and not self.use_block_table: + max_seqlen_q = pad_to_interval(max_seqlen_q, q_padding_interval_size, q_tokens) + max_kv_total = max_kv_read + q_tokens + if self.use_block_table: + max_seqlen_k_signature = ( + tuple((layer_type, 1) for layer_type in max_seqlen_k) if isinstance(max_seqlen_k, dict) else 1 + ) + else: + max_seqlen_k_signature = ( + tuple( + ( + layer_type, + pad_to_interval(max_seqlen_k[layer_type], kv_padding_interval_size, max_kv_total) + if kv_padding_interval_size > 0 + else max_seqlen_k[layer_type], + ) + for layer_type in max_seqlen_k + ) + if isinstance(max_seqlen_k, dict) + else ( + pad_to_interval(max_seqlen_k, kv_padding_interval_size, max_kv_total) + if kv_padding_interval_size > 0 + else max_seqlen_k + ) + ) + return ( + "decode_fast_path" if self.use_block_table else "varlen", + q_tokens, + max_kv_read, + 1 if self.use_block_table else max_seqlen_q, + max_seqlen_k_signature, + ) + + def get_graph( + self, + batch_data: dict[str, Any], + q_padding_interval_size: int = 0, + kv_padding_interval_size: int = 0, + ) -> torch.cuda.CUDAGraph | None: + graph_signature = self.get_graph_signature(batch_data, q_padding_interval_size, kv_padding_interval_size) + graph = self.graphs.get_graph(graph_signature) # If this point is reached, it means the next step will be a new graph capture if graph is None: self.graphs.plan_for_new_graph() - logger.info(f"Creating graph for {(self.num_q_tokens, self.max_kv_read) = }") + logger.info(f"Creating graph for {graph_signature = }") return graph - def set_graph(self, graph: torch.cuda.CUDAGraph) -> None: - self.graphs.set_graph(self.num_q_tokens, self.max_kv_read, graph) + def set_graph( + self, + graph: torch.cuda.CUDAGraph, + batch_data: dict[str, Any], + q_padding_interval_size: int = 0, + kv_padding_interval_size: int = 0, + ) -> None: + self.graphs.set_graph( + self.get_graph_signature(batch_data, q_padding_interval_size, kv_padding_interval_size), graph + ) class HostDeviceIOPair: @@ -663,6 +725,14 @@ def __init__( def get_cumulative_seqlens(self) -> tuple[torch.Tensor, dict[str, torch.Tensor]]: return self.io_pairs[self.current_pair].host_io.get_cumulative_seqlens() + @property + def num_q_tokens(self) -> int: + return self.io_pairs[self.current_pair].host_io.num_q_tokens + + @property + def max_kv_read(self) -> int: + return self.io_pairs[self.current_pair].host_io.max_kv_read + # The prepare_batch_tensor method also has to prepare the carry over ids def prepare_batch_tensors( self, @@ -742,11 +812,26 @@ def output_ids(self) -> torch.Tensor: # The output ids are used to copy_ the infered tokens: they need to be on the device return self.io_pairs[self.current_pair].device_io.output_ids - def get_graph(self) -> torch.cuda.CUDAGraph | None: - return self.io_pairs[self.current_pair].device_io.get_graph() + def get_graph( + self, + batch_data: dict[str, Any], + q_padding_interval_size: int = 0, + kv_padding_interval_size: int = 0, + ) -> torch.cuda.CUDAGraph | None: + return self.io_pairs[self.current_pair].device_io.get_graph( + batch_data, q_padding_interval_size, kv_padding_interval_size + ) - def set_graph(self, graph: torch.cuda.CUDAGraph) -> None: - self.io_pairs[self.current_pair].device_io.set_graph(graph) + def set_graph( + self, + graph: torch.cuda.CUDAGraph, + batch_data: dict[str, Any], + q_padding_interval_size: int = 0, + kv_padding_interval_size: int = 0, + ) -> None: + self.io_pairs[self.current_pair].device_io.set_graph( + graph, batch_data, q_padding_interval_size, kv_padding_interval_size + ) @property def use_block_table(self) -> bool: diff --git a/src/transformers/generation/continuous_batching/utils.py b/src/transformers/generation/continuous_batching/utils.py index a827d247110e..6ccc881573b1 100644 --- a/src/transformers/generation/continuous_batching/utils.py +++ b/src/transformers/generation/continuous_batching/utils.py @@ -22,6 +22,9 @@ from .requests import FutureRequestState, RequestState, RequestStatus, logger +GraphSignature = tuple[Any, ...] + + class CudaGraphBuffer: """A fixed-size dict for CUDA graphs with LRU eviction when full.""" @@ -29,7 +32,7 @@ def __init__(self, max_size: int) -> None: if max_size <= 0: raise ValueError(f"max_size must be positive, but got {max_size}") self.max_size = max_size - self._storage: OrderedDict[tuple[int, int], torch.cuda.CUDAGraph] = OrderedDict() + self._storage: OrderedDict[GraphSignature, torch.cuda.CUDAGraph] = OrderedDict() def __del__(self) -> None: original_max_size = self.max_size @@ -37,10 +40,11 @@ def __del__(self) -> None: self.plan_for_new_graph(silent=True) self.max_size = original_max_size - def get_graph(self, q_len: int, kv_len: int) -> torch.cuda.CUDAGraph | None: - graph = self._storage.get((q_len, kv_len)) + def get_graph(self, key: GraphSignature) -> torch.cuda.CUDAGraph | None: + # Replay safety sometimes depends on scalar kwargs as well as tensor extents, so the key is generic. + graph = self._storage.get(key) if graph is not None: - self._storage.move_to_end((q_len, kv_len)) + self._storage.move_to_end(key) return graph def plan_for_new_graph(self, silent: bool = False) -> None: @@ -50,11 +54,11 @@ def plan_for_new_graph(self, silent: bool = False) -> None: logger.info(f"Evicting graph for {evicted_key = }") evicted_graph.reset() - def set_graph(self, q_len: int, kv_len: int, graph: torch.cuda.CUDAGraph) -> None: + def set_graph(self, key: GraphSignature, graph: torch.cuda.CUDAGraph) -> None: # In our use case, this should not have any effect because we plan for a new graph before it is captured self.plan_for_new_graph() - logger.info(f"Setting graph for {q_len = }, {kv_len = }") - self._storage[(q_len, kv_len)] = graph + logger.info(f"Setting graph for {key = }") + self._storage[key] = graph def attn_mask_is_needed(config: PretrainedConfig) -> bool: diff --git a/tests/generation/test_continuous_batching.py b/tests/generation/test_continuous_batching.py index 98566466a060..abd0d54dd99f 100644 --- a/tests/generation/test_continuous_batching.py +++ b/tests/generation/test_continuous_batching.py @@ -16,6 +16,7 @@ import gc import itertools import unittest +from types import SimpleNamespace from typing import Any from unittest.mock import patch @@ -36,11 +37,12 @@ PagedAttentionCache, PagedAttentionMemoryHandler, SlidingAttentionCacheAllocator, + default_flash_attention_max_blocks_per_request, group_layers_by_attn_type, ) from transformers.generation.continuous_batching.cache_manager import FullAttentionCacheAllocator from transformers.generation.continuous_batching.continuous_api import ContinuousBatchProcessor, OutputRouter -from transformers.generation.continuous_batching.input_outputs import build_attention_mask +from transformers.generation.continuous_batching.input_outputs import ContinuousBatchingIOs, build_attention_mask from transformers.generation.continuous_batching.requests import GenerationOutput, RequestStatus from transformers.testing_utils import ( require_deterministic_for_xpu, @@ -202,6 +204,86 @@ def regular_generate( # Class for all continuous batching tests that do not require any accelerator. Usualy those test are faster to run. class ContinuousBatchingNoAcceleratorTest(unittest.TestCase): + def test_flash_attention_auto_max_blocks_per_request(self) -> None: + """FlashAttention auto-defaults should scale the block table with the resolved batch-token budget.""" + + config = SimpleNamespace(_attn_implementation="flash_attention_2") + self.assertEqual(default_flash_attention_max_blocks_per_request(config, 2048), 1) + self.assertEqual(default_flash_attention_max_blocks_per_request(config, 8192), 16) + self.assertEqual( + default_flash_attention_max_blocks_per_request(SimpleNamespace(_attn_implementation="sdpa"), 8192), 0 + ) + + def test_decode_fast_path_accepts_fa2(self) -> None: + """FA2 should keep the block-table decode path enabled when flash_attn_with_kvcache is available.""" + + cache = SimpleNamespace(max_blocks_per_request=4, num_sliding_attention_groups=0) + processor = SimpleNamespace(cache=cache, config=SimpleNamespace(_attn_implementation="flash_attention_2")) + + with ( + patch("torch.cuda.is_available", return_value=True), + patch( + "transformers.generation.continuous_batching.continuous_api.lazy_import_paged_flash_attention", + return_value=(None, object()), + ), + ): + ContinuousBatchProcessor._ensure_decode_fast_path_is_available(processor) + + self.assertEqual(cache.max_blocks_per_request, 4) + + def test_cuda_graph_signature_tracks_non_tensor_runtime_args(self) -> None: + """CUDA graph reuse must distinguish batches that share padded tensor sizes but not FA runtime ints.""" + + class _DummyLogitsProcessor: + tensors_required = 0 + + def fill_defaults(self, arg_storage: torch.Tensor) -> None: + return None + + def prepare_tensor_args(self, requests_in_batch: list[Any], arg_storage: torch.Tensor) -> None: + return None + + config = AutoConfig.from_pretrained("HuggingFaceTB/SmolLM-1.7B", attn_implementation="flash_attention_2") + cache = PagedAttentionCache( + config=config, + continuous_batching_config=ContinuousBatchingConfig(block_size=16, num_blocks=8, max_batch_tokens=32), + device="cpu", + ) + io = ContinuousBatchingIOs( + cache=cache, + config=config, + device=torch.device("cpu"), + model_dtype=torch.float16, + max_graphs=2, + return_logprobs=False, + logit_processor=_DummyLogitsProcessor(), + ) + + io.use_block_table = False + io.num_q_tokens = 16 + io.max_kv_read = 32 + base_batch_data = {"max_seqlen_q": 7, "max_seqlen_k": dict.fromkeys(io.max_seqlen_k, 19)} + base_signature = io.get_graph_signature(base_batch_data) + + self.assertNotEqual( + base_signature, + io.get_graph_signature({"max_seqlen_q": 11, "max_seqlen_k": base_batch_data["max_seqlen_k"]}), + "Graph signature should change when max_seqlen_q changes under the same padded Q/KV sizes.", + ) + + self.assertNotEqual( + base_signature, + io.get_graph_signature({"max_seqlen_q": 7, "max_seqlen_k": dict.fromkeys(io.max_seqlen_k, 23)}), + "Graph signature should change when max_seqlen_k changes under the same padded Q/KV sizes.", + ) + + io.use_block_table = True + self.assertNotEqual( + base_signature, + io.get_graph_signature(base_batch_data), + "Graph signature should change when switching between varlen and decode-fast-path batches.", + ) + @parameterized.expand( [ (None, None, "0"),