Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 9 additions & 7 deletions src/transformers/generation/continuous_batching/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,13 @@ def group_layers_by_attn_type(config: PreTrainedConfig) -> tuple[list[list[int]]
return layer_groups, group_types


def default_flash_attention_max_blocks_per_request(config: PreTrainedConfig, max_batch_tokens: int) -> int:
    """Pick the decode-fast-path block-table size for FlashAttention continuous batching.

    Returns 0 (fast path disabled) unless FlashAttention 2 or 3 was requested for `config`;
    otherwise the per-request block budget grows with the configured token budget.
    """
    flash_attention_requested = any(
        is_flash_attention_requested(config, version=fa_version) for fa_version in (2, 3)
    )
    if not flash_attention_requested:
        return 0
    # Small token budgets get a single block; large ones get 16 blocks per request.
    return 1 if max_batch_tokens <= 4096 else 16


@attach_tracer()
class PagedAttentionCache:
"""
Expand Down Expand Up @@ -207,15 +214,10 @@ def __init__(
f"{self.max_batch_tokens = } {num_attention_masks = }"
)

# If max_blocks_per_request is not set, the default value is 16 max blocks. With default block size of 256, this
# means a max sequence length of 4096 tokens for the fast decode path.
# Resolve the block-table size after max_batch_tokens is known so the auto path can scale with the final cache.
max_blocks_per_request = continuous_batching_config.max_blocks_per_request
if max_blocks_per_request is None:
max_blocks_per_request = 0
# logger.info( TODO: uncomment when we have good defaults
# f"max_blocks_per_request was not set, using {max_blocks_per_request}. This means max sequence "
# f"length for the decode fast path is {max_blocks_per_request * self.block_size}."
# )
max_blocks_per_request = default_flash_attention_max_blocks_per_request(config, self.max_batch_tokens)
self.max_blocks_per_request = max_blocks_per_request

# Initialize the cache
Expand Down
24 changes: 15 additions & 9 deletions src/transformers/generation/continuous_batching/continuous_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,8 +257,10 @@ def __del__(self) -> None:
def _ensure_decode_fast_path_is_available(self) -> None:
"""Ensures the decode fast path is available. If it is not, set the max blocks per request to 0."""
if self.cache.max_blocks_per_request > 0:
# NOTE: block table should be available with FA2 and FA3, but there seems to be an issue with FA2 atm
if is_flash_attention_requested(self.config, version=3):
# flash_attn_with_kvcache underpins the paged decode fast path for both FA2 and FA3.
if is_flash_attention_requested(self.config, version=2) or is_flash_attention_requested(
self.config, version=3
):
flash_attn_with_kvcache = lazy_import_paged_flash_attention(self.config._attn_implementation)[1]
conditions = [
self.cache.num_sliding_attention_groups == 0, # TODO: add support for sliding window layers
Expand All @@ -274,7 +276,7 @@ def _ensure_decode_fast_path_is_available(self) -> None:
else:
logger.warning(
f"Although {self.cache.max_blocks_per_request = }, the decode fast path is not available "
f"because the attention implementation is not FA3. Got {self.config._attn_implementation = }."
f"because the attention implementation is not FA2/FA3. Got {self.config._attn_implementation = }."
)
self.cache.max_blocks_per_request = 0

Expand Down Expand Up @@ -522,20 +524,22 @@ def _generation_step(self, model: nn.Module) -> None:

# Otherwise, we use create or replay the graph (cuda is available in this path)
else:
graph = self.inputs_and_outputs.get_graph()
graph = self.inputs_and_outputs.get_graph(
batch_data, self.q_padding_interval_size, self.kv_padding_interval_size
)
# Case: the graph already exists, so we replay it
if graph is not None:
with torch.cuda.stream(compute_stream):
graph.replay()
# Otherwise, the graph does not exist, so we create it
else:
args = (model, batch_data, carry_over_ids, prev_output_ids, output_ids)
self.capture_graph(forward_fn, compute_stream, *args)
self.capture_graph(forward_fn, compute_stream, batch_data, *args)

# In any case, we transfer the outputs to the host
self.inputs_and_outputs.retrieve_device_outputs()

def capture_graph(self, forward_fn: Any, compute_stream: torch.cuda.Stream, *args) -> None:
def capture_graph(self, forward_fn: Any, compute_stream: torch.cuda.Stream, batch_data: dict, *args) -> None:
# Warmup (ensures the right result is computed before capturing the graph)
with torch.cuda.stream(compute_stream):
forward_fn(*args)
Expand All @@ -547,7 +551,9 @@ def capture_graph(self, forward_fn: Any, compute_stream: torch.cuda.Stream, *arg
with torch.cuda.graph(graph, stream=compute_stream, pool=self.graph_pool, capture_error_mode="thread_local"):
forward_fn(*args)
# Store
self.inputs_and_outputs.set_graph(graph)
self.inputs_and_outputs.set_graph(
graph, batch_data, self.q_padding_interval_size, self.kv_padding_interval_size
)

@traced
def _forward_process_and_sample(
Expand Down Expand Up @@ -666,7 +672,7 @@ def warmup(
forward_fn = self._compiled_varlen or self._forward_process_and_sample
forward_fn_args = (model, batch_data, carry_over_ids, prev_output_ids, output_ids)
if self.use_cuda_graph:
self.capture_graph(forward_fn, compute_stream, *forward_fn_args)
self.capture_graph(forward_fn, compute_stream, batch_data, *forward_fn_args)
else:
with torch.cuda.stream(compute_stream):
forward_fn(*forward_fn_args)
Expand Down Expand Up @@ -704,7 +710,7 @@ def warmup(
forward_fn = self._compiled_decode or self._forward_process_and_sample
forward_fn_args = (model, batch_data, carry_over_ids, prev_output_ids, output_ids)
if self.use_cuda_graph:
self.capture_graph(forward_fn, compute_stream, *forward_fn_args)
self.capture_graph(forward_fn, compute_stream, batch_data, *forward_fn_args)
else:
with torch.cuda.stream(compute_stream):
forward_fn(*forward_fn_args)
Expand Down
109 changes: 97 additions & 12 deletions src/transformers/generation/continuous_batching/input_outputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from .cache import PagedAttentionCache
from .cb_logits_processors import ContinuousBatchingLogitsProcessorList
from .requests import TMP_TOKEN_ID, FutureRequestState, logger
from .utils import CudaGraphBuffer, aligned_divide, attn_mask_is_needed, build_attention_mask
from .utils import CudaGraphBuffer, aligned_divide, attn_mask_is_needed, build_attention_mask, pad_to_interval


@dataclass
Expand Down Expand Up @@ -483,8 +483,7 @@ def get_model_kwargs(self, use_padding: bool = False) -> dict[str, Any]:
# some models like Qwen3-4B-Instruct-2507, if we don't include these tokens in cumulative_seqlens_q, there are
# some NaNs in the output logits even for non-padded tokens.
if use_padding:
self.max_seqlen_q = max(self.max_seqlen_q, q_size - self.total_seqlen_q)
kwargs.max_seqlen_q = self.max_seqlen_q
kwargs.max_seqlen_q = max(kwargs.max_seqlen_q, q_size - self.total_seqlen_q)
self.cumulative_seqlens_q[self.true_batch_size + 1 :] = q_size
# FIXME: is there another way to avoid this? It has a very slight impact on performance (~5 tok/s)

Expand Down Expand Up @@ -531,16 +530,79 @@ def get_cb_kwargs(self) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
return 3 tensors to have the same interface as when using async batching."""
return self.carry_over_ids, self.output_ids, self.output_ids

def get_graph(self) -> torch.cuda.CUDAGraph | None:
graph = self.graphs.get_graph(self.num_q_tokens, self.max_kv_read)
    def get_graph_signature(
        self,
        batch_data: dict[str, Any],
        q_padding_interval_size: int = 0,
        kv_padding_interval_size: int = 0,
    ) -> tuple[Any, ...]:
        """Return the CUDA graph signature for the current batch.

        Continuous batching pads tensor inputs to static shapes for CUDA graph replay, but FA-style kernels also depend
        on non-tensor integer kwargs such as `max_seqlen_q` and `max_seqlen_k`. Reusing a graph across batches that
        share padded `(Q, KV)` sizes but differ on those integers can replay kernels with stale launch parameters.
        """
        q_tokens = self.num_q_tokens
        max_kv_read = self.max_kv_read
        max_seqlen_q = batch_data["max_seqlen_q"]
        max_seqlen_k = batch_data["max_seqlen_k"]
        # On the varlen path, round max_seqlen_q up to the padding interval so nearby batch
        # shapes map to the same signature and reuse one captured graph.
        if q_padding_interval_size > 0 and not self.use_block_table:
            max_seqlen_q = pad_to_interval(max_seqlen_q, q_padding_interval_size, q_tokens)
        max_kv_total = max_kv_read + q_tokens
        if self.use_block_table:
            # Decode fast path: the seqlen-k component of the key collapses to 1 per layer type
            # (or a plain 1 when max_seqlen_k is a scalar rather than a per-layer-type dict).
            max_seqlen_k_signature = (
                tuple((layer_type, 1) for layer_type in max_seqlen_k) if isinstance(max_seqlen_k, dict) else 1
            )
        else:
            # Varlen path: max_seqlen_k may be a per-layer-type dict or a plain int; either way,
            # optionally round each value up via pad_to_interval (third argument presumably an
            # upper bound — confirm against pad_to_interval in utils).
            max_seqlen_k_signature = (
                tuple(
                    (
                        layer_type,
                        pad_to_interval(max_seqlen_k[layer_type], kv_padding_interval_size, max_kv_total)
                        if kv_padding_interval_size > 0
                        else max_seqlen_k[layer_type],
                    )
                    for layer_type in max_seqlen_k
                )
                if isinstance(max_seqlen_k, dict)
                else (
                    pad_to_interval(max_seqlen_k, kv_padding_interval_size, max_kv_total)
                    if kv_padding_interval_size > 0
                    else max_seqlen_k
                )
            )
        # The leading tag keeps decode-fast-path and varlen graphs from ever colliding, and the
        # seqlen-q component is pinned to 1 on the block-table (single-query-token) path.
        return (
            "decode_fast_path" if self.use_block_table else "varlen",
            q_tokens,
            max_kv_read,
            1 if self.use_block_table else max_seqlen_q,
            max_seqlen_k_signature,
        )

def get_graph(
self,
batch_data: dict[str, Any],
q_padding_interval_size: int = 0,
kv_padding_interval_size: int = 0,
) -> torch.cuda.CUDAGraph | None:
graph_signature = self.get_graph_signature(batch_data, q_padding_interval_size, kv_padding_interval_size)
graph = self.graphs.get_graph(graph_signature)
# If this point is reached, it means the next step will be a new graph capture
if graph is None:
self.graphs.plan_for_new_graph()
logger.info(f"Creating graph for {(self.num_q_tokens, self.max_kv_read) = }")
logger.info(f"Creating graph for {graph_signature = }")
return graph

def set_graph(self, graph: torch.cuda.CUDAGraph) -> None:
self.graphs.set_graph(self.num_q_tokens, self.max_kv_read, graph)
    def set_graph(
        self,
        graph: torch.cuda.CUDAGraph,
        batch_data: dict[str, Any],
        q_padding_interval_size: int = 0,
        kv_padding_interval_size: int = 0,
    ) -> None:
        # Store the freshly captured graph under the same signature get_graph uses for lookup,
        # so replay only happens for batches with identical padded extents and scalar kwargs.
        self.graphs.set_graph(
            self.get_graph_signature(batch_data, q_padding_interval_size, kv_padding_interval_size), graph
        )


class HostDeviceIOPair:
Expand Down Expand Up @@ -663,6 +725,14 @@ def __init__(
def get_cumulative_seqlens(self) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:
return self.io_pairs[self.current_pair].host_io.get_cumulative_seqlens()

    @property
    def num_q_tokens(self) -> int:
        # Expose the active host IO's query-token count so graph-signature code can read it here.
        return self.io_pairs[self.current_pair].host_io.num_q_tokens

    @property
    def max_kv_read(self) -> int:
        # Expose the active host IO's maximum KV read extent for graph-signature computation.
        return self.io_pairs[self.current_pair].host_io.max_kv_read

# The prepare_batch_tensor method also has to prepare the carry over ids
def prepare_batch_tensors(
self,
Expand Down Expand Up @@ -742,11 +812,26 @@ def output_ids(self) -> torch.Tensor:
# The output ids are used to copy_ the infered tokens: they need to be on the device
return self.io_pairs[self.current_pair].device_io.output_ids

def get_graph(self) -> torch.cuda.CUDAGraph | None:
return self.io_pairs[self.current_pair].device_io.get_graph()
    def get_graph(
        self,
        batch_data: dict[str, Any],
        q_padding_interval_size: int = 0,
        kv_padding_interval_size: int = 0,
    ) -> torch.cuda.CUDAGraph | None:
        # Pure delegation: signature-aware graph lookup (and capture planning on a miss)
        # happens on the device IO of the currently active buffer pair.
        return self.io_pairs[self.current_pair].device_io.get_graph(
            batch_data, q_padding_interval_size, kv_padding_interval_size
        )

def set_graph(self, graph: torch.cuda.CUDAGraph) -> None:
self.io_pairs[self.current_pair].device_io.set_graph(graph)
    def set_graph(
        self,
        graph: torch.cuda.CUDAGraph,
        batch_data: dict[str, Any],
        q_padding_interval_size: int = 0,
        kv_padding_interval_size: int = 0,
    ) -> None:
        # Pure delegation: the active pair's device IO derives the graph signature from
        # batch_data and stores the captured graph under it.
        self.io_pairs[self.current_pair].device_io.set_graph(
            graph, batch_data, q_padding_interval_size, kv_padding_interval_size
        )

@property
def use_block_table(self) -> bool:
Expand Down
18 changes: 11 additions & 7 deletions src/transformers/generation/continuous_batching/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,25 +22,29 @@
from .requests import FutureRequestState, RequestState, RequestStatus, logger


# Hashable key identifying a captured CUDA graph. Kept as a generic tuple because replay
# safety can depend on scalar kernel kwargs as well as padded tensor extents.
GraphSignature = tuple[Any, ...]


class CudaGraphBuffer:
"""A fixed-size dict for CUDA graphs with LRU eviction when full."""

def __init__(self, max_size: int) -> None:
if max_size <= 0:
raise ValueError(f"max_size must be positive, but got {max_size}")
self.max_size = max_size
self._storage: OrderedDict[tuple[int, int], torch.cuda.CUDAGraph] = OrderedDict()
self._storage: OrderedDict[GraphSignature, torch.cuda.CUDAGraph] = OrderedDict()

    def __del__(self) -> None:
        # Teardown: force-evict every stored graph so their CUDA resources are reset.
        # Temporarily shrinking max_size makes plan_for_new_graph() drain the buffer.
        original_max_size = self.max_size
        self.max_size = 1  # 0 would cause an infinite loop, 1 is enough to clear all graphs
        self.plan_for_new_graph(silent=True)
        # Restore the configured capacity in case the object is still referenced afterwards.
        self.max_size = original_max_size

def get_graph(self, q_len: int, kv_len: int) -> torch.cuda.CUDAGraph | None:
graph = self._storage.get((q_len, kv_len))
def get_graph(self, key: GraphSignature) -> torch.cuda.CUDAGraph | None:
# Replay safety sometimes depends on scalar kwargs as well as tensor extents, so the key is generic.
graph = self._storage.get(key)
if graph is not None:
self._storage.move_to_end((q_len, kv_len))
self._storage.move_to_end(key)
return graph

def plan_for_new_graph(self, silent: bool = False) -> None:
Expand All @@ -50,11 +54,11 @@ def plan_for_new_graph(self, silent: bool = False) -> None:
logger.info(f"Evicting graph for {evicted_key = }")
evicted_graph.reset()

def set_graph(self, q_len: int, kv_len: int, graph: torch.cuda.CUDAGraph) -> None:
def set_graph(self, key: GraphSignature, graph: torch.cuda.CUDAGraph) -> None:
# In our use case, this should not have any effect because we plan for a new graph before it is captured
self.plan_for_new_graph()
logger.info(f"Setting graph for {q_len = }, {kv_len = }")
self._storage[(q_len, kv_len)] = graph
logger.info(f"Setting graph for {key = }")
self._storage[key] = graph


def attn_mask_is_needed(config: PretrainedConfig) -> bool:
Expand Down
Loading
Loading