diff --git a/lmdeploy/cli/cli.py b/lmdeploy/cli/cli.py index e0702aeeaf..842379c548 100644 --- a/lmdeploy/cli/cli.py +++ b/lmdeploy/cli/cli.py @@ -70,6 +70,7 @@ def add_parser_chat(): tb_group._group_actions.append(cache_max_entry_act) tb_group._group_actions.append(prefix_caching_act) tb_group._group_actions.append(quant_policy) + ArgumentHelper.linear_prefix_cache_interval_blocks(tb_group) ArgumentHelper.model_format(tb_group) ArgumentHelper.rope_scaling_factor(tb_group) ArgumentHelper.communicator(tb_group) diff --git a/lmdeploy/cli/serve.py b/lmdeploy/cli/serve.py index bb60d0b764..02d705c7c8 100644 --- a/lmdeploy/cli/serve.py +++ b/lmdeploy/cli/serve.py @@ -116,6 +116,7 @@ def add_parser_api_server(): cache_max_entry_act = ArgumentHelper.cache_max_entry_count(pt_group) cache_block_seq_len_act = ArgumentHelper.cache_block_seq_len(pt_group) prefix_caching_act = ArgumentHelper.enable_prefix_caching(pt_group) + linear_prefix_cache_interval_blocks_act = ArgumentHelper.linear_prefix_cache_interval_blocks(pt_group) max_prefill_token_num_act = ArgumentHelper.max_prefill_token_num(pt_group) quant_policy = ArgumentHelper.quant_policy(pt_group) model_format = ArgumentHelper.model_format(pt_group) @@ -141,6 +142,7 @@ def add_parser_api_server(): tb_group._group_actions.append(cache_max_entry_act) tb_group._group_actions.append(cache_block_seq_len_act) tb_group._group_actions.append(prefix_caching_act) + tb_group._group_actions.append(linear_prefix_cache_interval_blocks_act) tb_group._group_actions.append(max_prefill_token_num_act) tb_group._group_actions.append(quant_policy) tb_group._group_actions.append(model_format) @@ -209,13 +211,14 @@ def api_server(args): """Serve LLMs with restful api using fastapi.""" from lmdeploy.archs import autoget_backend - max_batch_size = args.max_batch_size if args.max_batch_size \ - else get_max_batch_size(args.device) backend = args.backend if backend != 'pytorch': # set auto backend mode backend = autoget_backend(args.model_path) + max_batch_size = args.max_batch_size if args.max_batch_size \ + else get_max_batch_size(args.device) + if backend == 'pytorch': from lmdeploy.messages import PytorchEngineConfig adapters = get_lora_adapters(args.adapters) @@ -268,6 +271,7 @@ def api_server(args): cache_max_entry_count=args.cache_max_entry_count, cache_block_seq_len=args.cache_block_seq_len, enable_prefix_caching=args.enable_prefix_caching, + linear_prefix_cache_interval_blocks=args.linear_prefix_cache_interval_blocks, max_prefill_token_num=args.max_prefill_token_num, num_tokens_per_iter=args.num_tokens_per_iter, max_prefill_iters=args.max_prefill_iters, diff --git a/lmdeploy/cli/utils.py b/lmdeploy/cli/utils.py index 39469c6f07..f399c8416c 100644 --- a/lmdeploy/cli/utils.py +++ b/lmdeploy/cli/utils.py @@ -568,6 +568,17 @@ def enable_prefix_caching(parser): default=False, help='Enable cache and match prefix') + @staticmethod + def linear_prefix_cache_interval_blocks(parser): + """Add argument linear_prefix_cache_interval_blocks to parser.""" + + return parser.add_argument('--linear-prefix-cache-interval-blocks', + type=int, + default=64, + help='Hybrid linear-attention prefix checkpoint interval in ' + 'KV cache blocks. Larger values reduce GDN checkpoint memory ' + 'usage but increase recompute after a prefix hit') + @staticmethod def num_tokens_per_iter(parser): return parser.add_argument('--num-tokens-per-iter', diff --git a/lmdeploy/messages.py b/lmdeploy/messages.py index 01b8ec2d41..d6fe07b5f5 100644 --- a/lmdeploy/messages.py +++ b/lmdeploy/messages.py @@ -229,6 +229,11 @@ class TurbomindEngineConfig: a k/v block, default to 64 enable_prefix_caching: enable cache prompts for block reuse, default to False + linear_prefix_cache_interval_blocks: hybrid linear-attention prefix + checkpoint interval, in KV cache blocks. Larger values reduce GDN + checkpoint memory but may require more recompute after a prefix + hit. Applies only to hybrid models with prefix caching enabled. + Default to 64 quant_policy: default to 0. When k/v is quantized into 4 or 8 bit, set it to 4 or 8, respectively rope_scaling_factor: scaling factor used for dynamic ntk, @@ -278,6 +283,7 @@ class TurbomindEngineConfig: cache_chunk_size: int = -1 cache_block_seq_len: int = 64 enable_prefix_caching: bool = False + linear_prefix_cache_interval_blocks: int = 64 quant_policy: int = 0 rope_scaling_factor: float = 0.0 use_logn_attn: bool = False @@ -298,6 +304,8 @@ def __post_init__(self): assert self.dtype in ['auto', 'float16', 'bfloat16'] assert self.tp >= 1, 'tp must be a positive integer' assert self.cache_max_entry_count > 0, 'invalid cache_max_entry_count' + if self.linear_prefix_cache_interval_blocks < 1: + raise ValueError('invalid linear_prefix_cache_interval_blocks') assert self.quant_policy in (0, 4, 8), 'invalid quant_policy' assert self.rope_scaling_factor >= 0, 'invalid rope_scaling_factor' assert self.max_prefill_token_num >= 0, \ diff --git a/src/turbomind/engine/engine.cc b/src/turbomind/engine/engine.cc index 391a034dae..b3b7af23f3 100644 --- a/src/turbomind/engine/engine.cc +++ b/src/turbomind/engine/engine.cc @@ -230,6 +230,7 @@ void Engine::Impl::CreateSequenceManager() param_.cache_max_block_count, param_.cache_chunk_size, param_.enable_prefix_caching, + param_.linear_prefix_cache_interval_blocks, tp_rank_, param_.attn_cp_size, core::Context::alloc(kDEVICE), diff --git a/src/turbomind/models/llama/BlockTrie.cc b/src/turbomind/models/llama/BlockTrie.cc index 4046741bd2..a61c194127 100644 --- a/src/turbomind/models/llama/BlockTrie.cc +++ b/src/turbomind/models/llama/BlockTrie.cc @@ -1,6 +1,7 @@ // Copyright (c) OpenMMLab. All rights reserved. #include "src/turbomind/models/llama/BlockTrie.h" +#include "src/turbomind/core/logger.h" #include "src/turbomind/models/llama/SequenceManager.h" namespace turbomind { @@ -14,19 +15,46 @@ size_t hash(const std::vector& vec) return seed; } -BlockTrie::BlockTrie(size_t block_len, std::shared_ptr block_manager): - block_seq_len_(block_len), block_manager_(block_manager) +BlockTrie::BlockTrie(size_t block_len, + std::shared_ptr block_manager, + int linear_prefix_cache_interval_blocks, + int linear_state_slot_capacity, + std::vector conv_state_shape, + DataType conv_state_dtype, + std::vector recurrent_state_shape, + DataType recurrent_state_dtype): + block_seq_len_(block_len), + block_manager_(block_manager), + linear_prefix_cache_interval_blocks_(linear_prefix_cache_interval_blocks), + linear_prefix_cache_interval_tokens_(linear_prefix_cache_interval_blocks * block_len), + linear_state_slot_capacity_(linear_state_slot_capacity), + conv_state_shape_(std::move(conv_state_shape)), + recurrent_state_shape_(std::move(recurrent_state_shape)), + conv_state_dtype_(conv_state_dtype), + recurrent_state_dtype_(recurrent_state_dtype) { root_ = std::make_shared(); + + if (linear_state_slot_capacity_ > 0) { + TM_CHECK_GT(linear_prefix_cache_interval_blocks_, 0); + TM_CHECK(!conv_state_shape_.empty()); + TM_CHECK(!recurrent_state_shape_.empty()); + linear_conv_states_.resize(linear_state_slot_capacity_); + linear_recurrent_states_.resize(linear_state_slot_capacity_); + free_linear_state_slots_.reserve(linear_state_slot_capacity_); + for (int slot = linear_state_slot_capacity_ - 1; slot >= 0; --slot) { + free_linear_state_slots_.push_back(slot); + } + } } -std::tuple BlockTrie::Match(const Sequence& seq) +BlockTrieMatch BlockTrie::Match(const Sequence& seq) { - BlockIds block_ids; - UniqueIds unique_ids; + BlockTrieMatch match; - auto node = root_; - auto first = seq.prompt.begin(); + auto node = root_; + auto first = seq.prompt.begin(); + auto linear_prefix_state = root_; // Warning: Do not use "<=" operator even when seq.prompt length is evenly // divisible by block_seq_len_. The model needs at least one input token to generate output. @@ -35,9 +63,12 @@ std::tuple BlockTrie::Match(const Sequence& seq) const size_t hash_key = hash(segment); if (const auto it = node->children.find(hash_key); it != node->children.end()) { if (segment == it->second->tokens) { - block_ids.push_back(it->second->block_id); - unique_ids.push_back(it->second->block_unique_id); + match.block_ids.push_back(it->second->block_id); + match.unique_ids.push_back(it->second->block_unique_id); node = it->second; + if (node->linear_state_slot >= 0) { + linear_prefix_state = node; + } first += block_seq_len_; } else { @@ -50,7 +81,14 @@ std::tuple BlockTrie::Match(const Sequence& seq) } } - return std::make_tuple(block_ids, unique_ids); + if (linear_prefix_state != root_) { + const int slot = linear_prefix_state->linear_state_slot; + match.linear_cache_len = linear_prefix_state->num_matched; + match.conv_states = LinearConvState(slot); + match.recurrent_states = LinearRecurrentState(slot); + } + + return match; } std::tuple BlockTrie::Cache(const Sequence& seq, const std::vector& tokens) @@ -66,7 +104,10 @@ std::tuple BlockTrie::Cache(const Sequence& seq, const std: const int n_blocks = std::min(seq.cache_len, (int)tokens.size()) / block_seq_len_; - int new_cached = 0; + int new_cached = 0; + const int checkpoint_base = linear_prefix_cache_interval_tokens_ ? + seq.pending_linear_prefix_capture_base_len / linear_prefix_cache_interval_tokens_ : + 0; for (int idx = 0; idx < n_blocks; ++idx) { auto start = tokens.begin() + idx * block_seq_len_; @@ -75,14 +116,16 @@ std::tuple BlockTrie::Cache(const Sequence& seq, const std: const std::vector segment(start, end); const size_t hash_key = hash(segment); // TODO(lvhan): add salt to ensure the hash security - int block_id = seq.blocks[idx]; - uint64_t block_unique_id = seq.block_unique_ids[idx]; + int block_id = seq.blocks[idx]; + uint64_t block_unique_id = seq.block_unique_ids[idx]; + const int num_matched = (idx + 1) * block_seq_len_; if (auto it = node->children.find(hash_key); it != node->children.end()) { if (segment == it->second->tokens) { // fast-forward node = it->second; node->block_id = block_id; node->block_unique_id = block_unique_id; + node->num_matched = num_matched; } else { TM_LOG_WARN("Hash collision detected"); @@ -96,8 +139,24 @@ std::tuple BlockTrie::Cache(const Sequence& seq, const std: node->tokens = segment; node->block_id = block_id; node->block_unique_id = block_unique_id; + node->num_matched = num_matched; new_cached += block_seq_len_; } + if (IsLinearCheckpointNode(num_matched)) { + const int checkpoint_idx = num_matched / linear_prefix_cache_interval_tokens_ - checkpoint_base - 1; + if (checkpoint_idx >= 0 && checkpoint_idx < seq.pending_linear_prefix_capture_count + && seq.pending_linear_prefix_conv_states && seq.pending_linear_prefix_recurrent_states) { + if (node->linear_state_slot < 0) { + node->linear_state_slot = AcquireLinearStateSlot(); + } + if (node->linear_state_slot >= 0) { + Copy(seq.pending_linear_prefix_conv_states.slice(checkpoint_idx).squeeze(0), + LinearConvState(node->linear_state_slot)); + Copy(seq.pending_linear_prefix_recurrent_states.slice(checkpoint_idx).squeeze(0), + LinearRecurrentState(node->linear_state_slot)); + } + } + } cache_block_ids.emplace_back(block_id); cache_block_unique_ids.emplace_back(block_unique_id); } @@ -117,6 +176,7 @@ void BlockTrie::DFS(std::shared_ptr& node) for (auto it = node->children.begin(); it != node->children.end();) { if (block_manager_->unique_id(it->second->block_id) != it->second->block_unique_id) { // child invalid + ReleaseLinearPrefixState(it->second); it = node->children.erase(it); } else { @@ -126,4 +186,77 @@ void BlockTrie::DFS(std::shared_ptr& node) } } +void BlockTrie::ReleaseLinearPrefixState(std::shared_ptr& node) +{ + if (!node) { + return; + } + for (auto& [_, child] : node->children) { + ReleaseLinearPrefixState(child); + } + ReleaseLinearStateSlot(node->linear_state_slot); + node->linear_state_slot = -1; +} + +bool BlockTrie::IsLinearCheckpointNode(int num_matched) const +{ + return linear_prefix_cache_interval_tokens_ > 0 && num_matched > 0 + && num_matched % linear_prefix_cache_interval_tokens_ == 0; +} + +int BlockTrie::AcquireLinearStateSlot() +{ + if (free_linear_state_slots_.empty()) { + if (!warned_linear_state_pool_exhausted_) { + TM_LOG_WARN("[BlockTrie] linear prefix checkpoint pool exhausted; deeper hybrid prefix checkpoints " + "will be skipped until cached entries are evicted"); + warned_linear_state_pool_exhausted_ = true; + } + return -1; + } + const int slot = free_linear_state_slots_.back(); + free_linear_state_slots_.pop_back(); + try { + if (!linear_conv_states_[slot]) { + linear_conv_states_[slot] = {conv_state_shape_, conv_state_dtype_, kDEVICE}; + } + if (!linear_recurrent_states_[slot]) { + linear_recurrent_states_[slot] = {recurrent_state_shape_, recurrent_state_dtype_, kDEVICE}; + } + } + catch (const std::exception& e) { + free_linear_state_slots_.push_back(slot); + if (!warned_linear_state_pool_oom_) { + TM_LOG_WARN("[BlockTrie] failed to allocate hybrid prefix checkpoint state: {}. " + "Further GDN prefix checkpoints will be skipped until memory is freed.", + e.what()); + warned_linear_state_pool_oom_ = true; + } + return -1; + } + return slot; +} + +void BlockTrie::ReleaseLinearStateSlot(int slot) +{ + if (slot >= 0) { + // Keep tensors allocated so recycled slots avoid repeated GPU allocations. + free_linear_state_slots_.push_back(slot); + } +} + +Tensor BlockTrie::LinearConvState(int slot) const +{ + TM_CHECK_GE(slot, 0); + TM_CHECK_LT(slot, (int)linear_conv_states_.size()); + return linear_conv_states_[slot]; +} + +Tensor BlockTrie::LinearRecurrentState(int slot) const +{ + TM_CHECK_GE(slot, 0); + TM_CHECK_LT(slot, (int)linear_recurrent_states_.size()); + return linear_recurrent_states_[slot]; +} + } // namespace turbomind diff --git a/src/turbomind/models/llama/BlockTrie.h b/src/turbomind/models/llama/BlockTrie.h index 75381c3bdd..441529aa29 100644 --- a/src/turbomind/models/llama/BlockTrie.h +++ b/src/turbomind/models/llama/BlockTrie.h @@ -2,6 +2,7 @@ #pragma once +#include "src/turbomind/core/core.h" #include "src/turbomind/models/llama/BlockManager.h" #include #include @@ -18,11 +19,27 @@ struct TrieNode { int block_id; uint64_t block_unique_id; int num_matched; + int linear_state_slot = -1; +}; + +struct BlockTrieMatch { + BlockIds block_ids; + UniqueIds unique_ids; + int linear_cache_len = 0; + Tensor conv_states; + Tensor recurrent_states; }; class BlockTrie { public: - explicit BlockTrie(size_t block_len, std::shared_ptr block_manager); + explicit BlockTrie(size_t block_len, + std::shared_ptr block_manager, + int linear_prefix_cache_interval_blocks = 0, + int linear_state_slot_capacity = 0, + std::vector conv_state_shape = {}, + DataType conv_state_dtype = {}, + std::vector recurrent_state_shape = {}, + DataType recurrent_state_dtype = {}); /** * @brief Attempt to match cached key-value (KV) blocks for a given sequence. @@ -38,7 +55,7 @@ class BlockTrie { * * @note If no blocks are matched, all containers in the returned tuple will be empty. */ - std::tuple Match(const Sequence& seq); + BlockTrieMatch Match(const Sequence& seq); /** * @brief Cache the key-value (KV) blocks of a given sequence. @@ -61,7 +78,13 @@ class BlockTrie { void Verify(); private: - void DFS(std::shared_ptr& node); + void DFS(std::shared_ptr& node); + void ReleaseLinearPrefixState(std::shared_ptr& node); + bool IsLinearCheckpointNode(int num_matched) const; + int AcquireLinearStateSlot(); + void ReleaseLinearStateSlot(int slot); + Tensor LinearConvState(int slot) const; + Tensor LinearRecurrentState(int slot) const; private: size_t block_seq_len_; @@ -69,6 +92,19 @@ class BlockTrie { std::shared_ptr block_manager_; std::shared_ptr root_; + + int linear_prefix_cache_interval_blocks_{}; + int linear_prefix_cache_interval_tokens_{}; + int linear_state_slot_capacity_{}; + std::vector conv_state_shape_; + std::vector recurrent_state_shape_; + DataType conv_state_dtype_{}; + DataType recurrent_state_dtype_{}; + std::vector linear_conv_states_; + std::vector linear_recurrent_states_; + std::vector free_linear_state_slots_; + bool warned_linear_state_pool_exhausted_{false}; + bool warned_linear_state_pool_oom_{false}; }; } // namespace turbomind diff --git a/src/turbomind/models/llama/GatedDeltaNetLayer.cc b/src/turbomind/models/llama/GatedDeltaNetLayer.cc index e5a7459143..7be70d94ac 100644 --- a/src/turbomind/models/llama/GatedDeltaNetLayer.cc +++ b/src/turbomind/models/llama/GatedDeltaNetLayer.cc @@ -24,9 +24,12 @@ GatedDeltaNetLayer::GatedDeltaNetLayer(const ModelParam& model, key_dim_(num_k_heads_ * key_head_dim_), value_dim_(num_v_heads_ * value_head_dim_), conv_dim_(key_dim_ * 2 + value_dim_), + enable_linear_prefix_caching_(engine.enable_prefix_caching), + linear_prefix_cache_interval_tokens_(attn.cache_block_seq_len * engine.linear_prefix_cache_interval_blocks), norm_eps_(model.norm_eps), dtype_(model.data_type), state_dtype_(model.linear_state_dtype), + is_warm_up_(*ctx.is_warm_up), linear_(*ctx.linear) { layer_types_ = model.layer_types; @@ -45,10 +48,17 @@ GatedDeltaNetLayer::GatedDeltaNetLayer(const ModelParam& model, conv_dim_, d_conv_, num_linear_layers_); - if (num_linear_layers_ > 0) { + prefix_capture_state_bytes_ = static_cast(num_linear_layers_) * d_conv_ * conv_dim_ * byte_size(dtype_) + + static_cast(num_linear_layers_) * num_v_heads_ * key_head_dim_ + * value_head_dim_ * byte_size(state_dtype_); conv_state_ptrs_buf_ = {engine.max_batch_size, kCPUpinned}; recurrent_state_ptrs_buf_ = {engine.max_batch_size, kCPUpinned}; + if (enable_linear_prefix_caching_) { + TM_CHECK_GT(linear_prefix_cache_interval_tokens_, 0); + conv_capture_ptrs_buf_ = {engine.max_batch_size, kCPUpinned}; + recurrent_capture_ptrs_buf_ = {engine.max_batch_size, kCPUpinned}; + } } for (int i = 0; i < phases; ++i) { @@ -56,6 +66,10 @@ GatedDeltaNetLayer::GatedDeltaNetLayer(const ModelParam& model, if (num_linear_layers_ > 0) { data_.at(i).conv_state_ptrs = empty_like(conv_state_ptrs_buf_, kDEVICE); data_.at(i).recurrent_state_ptrs = empty_like(recurrent_state_ptrs_buf_, kDEVICE); + if (enable_linear_prefix_caching_) { + data_.at(i).conv_capture_ptrs = empty_like(conv_capture_ptrs_buf_, kDEVICE); + data_.at(i).recurrent_capture_ptrs = empty_like(recurrent_capture_ptrs_buf_, kDEVICE); + } } } @@ -91,6 +105,25 @@ void GatedDeltaNetLayer::Run(BatchOp op, int phase, TensorMap& env) d.q_offsets = env.at("q_offsets").buffer().borrow(); d.k_offsets = env.at("k_offsets").buffer().borrow(); } + else if (op == BatchOp::kUpdate) { + auto& d = data_.at(phase); + for (int i = 0; i < d.batch_size; ++i) { + auto& s = *d.seqs[i]; + if (d.capture_counts.empty() || d.capture_counts[i] == 0) { + s.pending_linear_prefix_conv_states = {}; + s.pending_linear_prefix_recurrent_states = {}; + s.pending_linear_prefix_capture_count = 0; + s.pending_linear_prefix_capture_base_len = 0; + continue; + } + s.pending_linear_prefix_conv_states = + d.conv_prefix_checkpoints.slice(d.capture_offsets[i], d.capture_counts[i]); + s.pending_linear_prefix_recurrent_states = + d.recurrent_prefix_checkpoints.slice(d.capture_offsets[i], d.capture_counts[i]); + s.pending_linear_prefix_capture_count = d.capture_counts[i]; + s.pending_linear_prefix_capture_base_len = d.history_lens[i]; + } + } } void GatedDeltaNetLayer::Setup(int phase, TensorMap& env) @@ -99,17 +132,22 @@ void GatedDeltaNetLayer::Setup(int phase, TensorMap& env) const auto& b = *env.at("batch").data()[0]; d.batch_size = b.rc.size(); - d.rc.resize(d.batch_size); + d.seqs.resize(d.batch_size); d.input_lens.resize(d.batch_size); + d.history_lens.resize(d.batch_size); + d.capture_counts.assign(d.batch_size, 0); + d.capture_offsets.assign(d.batch_size + 1, 0); + d.total_capture_count = 0; d.conv_states.resize(d.batch_size); d.recurrent_states.resize(d.batch_size); for (int i = 0; i < d.batch_size; ++i) { - d.rc[i] = b.rc[i].get(); - d.input_lens[i] = b.rc[i]->input_len; + d.seqs[i] = b.rc[i]->seq; + d.input_lens[i] = b.rc[i]->input_len; + d.history_lens[i] = b.rc[i]->history_len; - auto& s = *b.rc[i]->seq; + auto& s = *d.seqs[i]; TM_CHECK(s.conv_states && s.recurrent_states) << "Linear-attention state slot is not bound for sequence " << s.id; if (s.linear_states_need_reset) { @@ -127,10 +165,114 @@ void GatedDeltaNetLayer::Setup(int phase, TensorMap& env) conv_state_ptrs_buf_[i] = d.conv_states[i].raw_data(); recurrent_state_ptrs_buf_[i] = d.recurrent_states[i].raw_data(); + + if (enable_linear_prefix_caching_) { + int capture_count = 0; + // Warm-up requests never publish prefix cache entries, so avoid + // provisioning large checkpoint buffers for synthetic prefills. + if (!is_warm_up_ && !s.prompt.empty() && d.input_lens[i] > 0) { + const int prompt_limit = std::min(s.prompt.size(), d.history_lens[i] + d.input_lens[i]); + if (prompt_limit > d.history_lens[i]) { + capture_count = prompt_limit / linear_prefix_cache_interval_tokens_ + - d.history_lens[i] / linear_prefix_cache_interval_tokens_; + } + } + d.capture_counts[i] = capture_count; + d.capture_offsets[i] = d.total_capture_count; + d.total_capture_count += capture_count; + } + } + d.capture_offsets[d.batch_size] = d.total_capture_count; + + if (enable_linear_prefix_caching_) { + if (!EnsurePrefixCaptureCapacity(d, d.total_capture_count)) { + std::fill(d.capture_counts.begin(), d.capture_counts.end(), 0); + std::fill(d.capture_offsets.begin(), d.capture_offsets.end(), 0); + d.total_capture_count = 0; + } + for (int i = 0; i < d.batch_size; ++i) { + const int capture_count = d.capture_counts[i]; + conv_capture_ptrs_buf_[i] = + capture_count ? d.conv_prefix_checkpoints.slice(d.capture_offsets[i], capture_count).raw_data() : + nullptr; + recurrent_capture_ptrs_buf_[i] = + capture_count ? d.recurrent_prefix_checkpoints.slice(d.capture_offsets[i], capture_count).raw_data() : + nullptr; + } } Copy(conv_state_ptrs_buf_, d.batch_size, d.conv_state_ptrs); Copy(recurrent_state_ptrs_buf_, d.batch_size, d.recurrent_state_ptrs); + if (enable_linear_prefix_caching_) { + Copy(conv_capture_ptrs_buf_, d.batch_size, d.conv_capture_ptrs); + Copy(recurrent_capture_ptrs_buf_, d.batch_size, d.recurrent_capture_ptrs); + } +} + +bool GatedDeltaNetLayer::EnsurePrefixCaptureCapacity(Data& d, int capture_count) +{ + if (capture_count <= 0) { + return true; + } + if (d.conv_prefix_checkpoints && d.conv_prefix_checkpoints.shape(0) >= capture_count + && d.recurrent_prefix_checkpoints && d.recurrent_prefix_checkpoints.shape(0) >= capture_count) { + return true; + } + if (!CanAllocatePrefixCapture(capture_count)) { + return false; + } + try { + d.conv_prefix_checkpoints = {{capture_count, num_linear_layers_, d_conv_, conv_dim_}, dtype_, kDEVICE}; + d.recurrent_prefix_checkpoints = { + {capture_count, num_linear_layers_, num_v_heads_, key_head_dim_, value_head_dim_}, state_dtype_, kDEVICE}; + return true; + } + catch (const std::exception& e) { + if (!warned_prefix_capture_oom_) { + TM_LOG_WARN("[GDN] failed to allocate hybrid prefix capture staging for {} slots: {}. " + "This batch will run without storing new GDN prefix checkpoints.", + capture_count, + e.what()); + warned_prefix_capture_oom_ = true; + } + return false; + } +} + +bool GatedDeltaNetLayer::CanAllocatePrefixCapture(int capture_count) +{ + if (capture_count <= 0 || prefix_capture_state_bytes_ == 0) { + return true; + } + + // Prefix checkpoints are opportunistic acceleration data. Keep a generous + // safety margin so long-context requests continue running even if we must + // skip storing new GDN checkpoints for that batch. + constexpr size_t kPrefixCaptureMaxBytes = size_t{1} << 30; // 1 GiB + constexpr size_t kPrefixCaptureSafetyBytes = size_t{256} << 20; // 256 MiB + + size_t free_bytes{}; + size_t total_bytes{}; + check_cuda_error(cudaMemGetInfo(&free_bytes, &total_bytes)); + (void)total_bytes; + + const size_t budget_bytes = free_bytes > kPrefixCaptureSafetyBytes ? + std::min(free_bytes - kPrefixCaptureSafetyBytes, kPrefixCaptureMaxBytes) : + size_t{0}; + const size_t requested_bytes = static_cast(capture_count) * prefix_capture_state_bytes_; + + if (requested_bytes > budget_bytes) { + if (!warned_prefix_capture_budget_) { + TM_LOG_WARN("[GDN] skipping hybrid prefix checkpoint capture for this batch: requested {:.2f} MB, " + "budget {:.2f} MB. Prefix caching remains enabled, but this batch will not store new " + "linear-attention checkpoints.", + requested_bytes / (1024.0 * 1024.0), + budget_bytes / (1024.0 * 1024.0)); + warned_prefix_capture_budget_ = true; + } + return false; + } + return true; } static int linear_layer_index(int layer_id, const std::vector& layer_types) @@ -202,25 +344,47 @@ void GatedDeltaNetLayer::Forward(ForwardParam p) Tensor attn_out{{token_num, value_dim_}, dtype, device}; Tensor conv_out{{token_num, conv_dim_}, dtype, device}; - const int state_layer_idx = linear_layer_index(p.layer_id, layer_types_); - const int conv_state_layer_offset = state_layer_idx * (conv_dim_ * d_conv_); - const int recurrent_state_layer_offset = state_layer_idx * (num_v_heads_ * key_head_dim_ * value_head_dim_); + const int state_layer_idx = linear_layer_index(p.layer_id, layer_types_); + const int conv_state_layer_offset = state_layer_idx * (conv_dim_ * d_conv_); + const int recurrent_state_layer_offset = state_layer_idx * (num_v_heads_ * key_head_dim_ * value_head_dim_); + const int conv_capture_stride = num_linear_layers_ * d_conv_ * conv_dim_; + const int recurrent_capture_stride = num_linear_layers_ * num_v_heads_ * key_head_dim_ * value_head_dim_; + const bool has_prefix_captures = enable_linear_prefix_caching_ && pd.total_capture_count > 0; // ----- 3a. Fused Causal Conv1d + SiLU (all requests) ----- // all_proj carries the non-contiguous qkv slice (stride = all_col); // in_stride is derived from all_proj.stride(0) inside the launcher. - invokeFusedConv1dSiLU(conv_out, - all_proj, - weights.conv1d, - Tensor{}, - pd.conv_state_ptrs, - pd.q_offsets, - pd.k_offsets, - pd.batch_size, - conv_state_layer_offset, - sm_count_, - work_counter_.data(), - stream); + if (has_prefix_captures) { + invokeFusedConv1dSiLU(conv_out, + all_proj, + weights.conv1d, + Tensor{}, + pd.conv_state_ptrs, + pd.q_offsets, + pd.k_offsets, + pd.batch_size, + conv_state_layer_offset, + sm_count_, + work_counter_.data(), + stream, + pd.conv_capture_ptrs, + conv_capture_stride, + linear_prefix_cache_interval_tokens_); + } + else { + invokeFusedConv1dSiLU(conv_out, + all_proj, + weights.conv1d, + Tensor{}, + pd.conv_state_ptrs, + pd.q_offsets, + pd.k_offsets, + pd.batch_size, + conv_state_layer_offset, + sm_count_, + work_counter_.data(), + stream); + } sync_check_cuda_error(); // ----- 3b. Gated Delta Rule ----- @@ -230,7 +394,7 @@ void GatedDeltaNetLayer::Forward(ForwardParam p) { int decode_count = 0; for (int i = 0; i < pd.batch_size; ++i) { - if (pd.input_lens[i] <= 1) + if (pd.input_lens[i] <= 1 && (!has_prefix_captures || pd.capture_counts[i] == 0)) ++decode_count; else break; @@ -262,19 +426,42 @@ void GatedDeltaNetLayer::Forward(ForwardParam p) // Prefill on aux stream (higher priority) auto pf_state = pd.recurrent_state_ptrs.slice(decode_count, prefill_count); auto pf_q = pd.q_offsets.slice(decode_count, prefill_count + 1); - invokeChunkedGatedDeltaRuleBatched(attn_out, - conv_out, - beta, - g, - pf_state, - pf_q, - prefill_count, - num_k_heads_, - recurrent_state_layer_offset, - state_dtype_, - sm_count_, - work_counter_.data(), - aux_stream_); + auto pf_k = pd.k_offsets.slice(decode_count, prefill_count + 1); + if (has_prefix_captures) { + auto pf_capture = pd.recurrent_capture_ptrs.slice(decode_count, prefill_count); + invokeChunkedGatedDeltaRuleBatched(attn_out, + conv_out, + beta, + g, + pf_state, + pf_q, + pf_k, + prefill_count, + num_k_heads_, + recurrent_state_layer_offset, + state_dtype_, + sm_count_, + work_counter_.data(), + aux_stream_, + pf_capture, + recurrent_capture_stride, + linear_prefix_cache_interval_tokens_); + } + else { + invokeChunkedGatedDeltaRuleBatched(attn_out, + conv_out, + beta, + g, + pf_state, + pf_q, + prefill_count, + num_k_heads_, + recurrent_state_layer_offset, + state_dtype_, + sm_count_, + work_counter_.data(), + aux_stream_); + } // Join: main stream waits for prefill to finish check_cuda_error(cudaEventRecord(ev_after_, aux_stream_)); @@ -300,20 +487,42 @@ void GatedDeltaNetLayer::Forward(ForwardParam p) else if (prefill_count > 0) { auto state_slice = pd.recurrent_state_ptrs.slice(decode_count, prefill_count); auto q_slice = pd.q_offsets.slice(decode_count, prefill_count + 1); - invokeChunkedGatedDeltaRuleBatched(attn_out, - conv_out, - beta, - g, - state_slice, - q_slice, - prefill_count, - num_k_heads_, - recurrent_state_layer_offset, - state_dtype_, - sm_count_, - work_counter_.data(), - stream); - // invokeChunkedGatedDeltaRuleBatched + auto k_slice = pd.k_offsets.slice(decode_count, prefill_count + 1); + if (has_prefix_captures) { + auto capture_slice = pd.recurrent_capture_ptrs.slice(decode_count, prefill_count); + invokeChunkedGatedDeltaRuleBatched(attn_out, + conv_out, + beta, + g, + state_slice, + q_slice, + k_slice, + prefill_count, + num_k_heads_, + recurrent_state_layer_offset, + state_dtype_, + sm_count_, + work_counter_.data(), + stream, + capture_slice, + recurrent_capture_stride, + linear_prefix_cache_interval_tokens_); + } + else { + invokeChunkedGatedDeltaRuleBatched(attn_out, + conv_out, + beta, + g, + state_slice, + q_slice, + prefill_count, + num_k_heads_, + recurrent_state_layer_offset, + state_dtype_, + sm_count_, + work_counter_.data(), + stream); + } } } sync_check_cuda_error(); diff --git a/src/turbomind/models/llama/GatedDeltaNetLayer.h b/src/turbomind/models/llama/GatedDeltaNetLayer.h index 67e240c891..7c3a953f3f 100644 --- a/src/turbomind/models/llama/GatedDeltaNetLayer.h +++ b/src/turbomind/models/llama/GatedDeltaNetLayer.h @@ -9,6 +9,8 @@ namespace turbomind { +struct Sequence; + class GatedDeltaNetLayer { public: struct ForwardParam { @@ -33,7 +35,11 @@ class GatedDeltaNetLayer { void Forward(ForwardParam p); private: + struct Data; + void Setup(int phase, TensorMap& env); + bool EnsurePrefixCaptureCapacity(Data& d, int capture_count); + bool CanAllocatePrefixCapture(int capture_count); // Model dimensions int hidden_units_; @@ -47,30 +53,44 @@ class GatedDeltaNetLayer { int conv_dim_; // key_dim * 2 + value_dim int num_linear_layers_; // count of linear attention layers for state sizing std::vector layer_types_; // model layer types for index mapping + bool enable_linear_prefix_caching_{false}; + int linear_prefix_cache_interval_tokens_{0}; float norm_eps_; DataType dtype_; DataType state_dtype_; // recurrent state dtype (may differ from dtype_ for float32 state) + int& is_warm_up_; + size_t prefix_capture_state_bytes_{0}; LlamaLinear& linear_; // Per-phase batch data (mirrors UnifiedAttentionLayer pattern) struct Data { - std::vector rc; // borrowed batch RequestCache pointers - std::vector input_lens; // snapshot of input_len per request (captured at Setup time) - int batch_size = 0; - Buffer_ q_offsets; // cumulative input-token offsets, device buffer - Buffer_ k_offsets; // cumulative key (history+input) offsets, device buffer - std::vector conv_states; - std::vector recurrent_states; - Buffer_ conv_state_ptrs; - Buffer_ recurrent_state_ptrs; + std::vector seqs; // borrowed live sequence pointers + std::vector input_lens; // snapshot of input_len per request (captured at Setup time) + std::vector history_lens; + std::vector capture_counts; + std::vector capture_offsets; + int batch_size = 0; + int total_capture_count = 0; + Buffer_ q_offsets; // cumulative input-token offsets, device buffer + Buffer_ k_offsets; // cumulative key (history+input) offsets, device buffer + std::vector conv_states; + std::vector recurrent_states; + Buffer_ conv_state_ptrs; + Buffer_ recurrent_state_ptrs; + Buffer_ conv_capture_ptrs; + Buffer_ recurrent_capture_ptrs; + Tensor conv_prefix_checkpoints; + Tensor recurrent_prefix_checkpoints; }; std::vector data_; // staging buffers Buffer_ conv_state_ptrs_buf_; Buffer_ recurrent_state_ptrs_buf_; + Buffer_ conv_capture_ptrs_buf_; + Buffer_ recurrent_capture_ptrs_buf_; // Queried once at construction; passed to all three kernel launchers. int sm_count_{1}; @@ -80,6 +100,9 @@ class GatedDeltaNetLayer { cudaStream_t aux_stream_{}; cudaEvent_t ev_before_{}; // main→aux: prior work done cudaEvent_t ev_after_{}; // aux→main: prefill done + + bool warned_prefix_capture_oom_{false}; + bool warned_prefix_capture_budget_{false}; }; } // namespace turbomind diff --git a/src/turbomind/models/llama/SequenceManager.cc b/src/turbomind/models/llama/SequenceManager.cc index ea2817727d..f2aaaedd81 100644 --- a/src/turbomind/models/llama/SequenceManager.cc +++ b/src/turbomind/models/llama/SequenceManager.cc @@ -31,6 +31,71 @@ std::string vector2string(const std::vector& data) return ss.str(); } +namespace hybrid_prefix_budget { + +size_t GetShapeBytes(const std::vector& shape, DataType dtype) +{ + if (shape.empty()) { + return 0; + } + + size_t numel = 1; + for (const auto dim : shape) { + TM_CHECK_GT(dim, 0); + numel *= static_cast(dim); + } + return static_cast(byte_size(dtype, static_cast(numel))); +} + +size_t GetHybridCheckpointSlotBytes(const std::vector& conv_state_shape, + DataType conv_state_dtype, + const std::vector& recurrent_state_shape, + DataType recurrent_state_dtype) +{ + return GetShapeBytes(conv_state_shape, conv_state_dtype) + + GetShapeBytes(recurrent_state_shape, recurrent_state_dtype); +} + +unsigned __int128 GetHybridCacheBytesForBlocks(size_t kv_block_count, + size_t kv_block_bytes, + size_t checkpoint_slot_bytes, + int checkpoint_interval_blocks) +{ + const size_t checkpoint_slots = checkpoint_interval_blocks > 0 ? kv_block_count / checkpoint_interval_blocks : 0; + return static_cast(kv_block_count) * kv_block_bytes + + static_cast(checkpoint_slots) * checkpoint_slot_bytes; +} + +size_t SolveHybridKvBlockCount(size_t budget_bytes, + size_t kv_block_bytes, + size_t checkpoint_slot_bytes, + int checkpoint_interval_blocks) +{ + if (!budget_bytes || !kv_block_bytes) { + return 0; + } + if (!checkpoint_slot_bytes || checkpoint_interval_blocks <= 0) { + return budget_bytes / kv_block_bytes; + } + + size_t lo = 0; + size_t hi = budget_bytes / kv_block_bytes; + while (lo < hi) { + const size_t mid = lo + (hi - lo + 1) / 2; + const auto required_bytes = + GetHybridCacheBytesForBlocks(mid, kv_block_bytes, checkpoint_slot_bytes, checkpoint_interval_blocks); + if (required_bytes <= budget_bytes) { + lo = mid; + } + else { + hi = mid - 1; + } + } + return lo; +} + +} // namespace hybrid_prefix_budget + SequenceManager::SequenceManager(const ModelParam& model_param, DataType runtime_dtype, int cache_block_seq_len, @@ -39,14 +104,20 @@ SequenceManager::SequenceManager(const ModelParam& model_param, double block_count, int chunk_size, bool enable_prefix_caching, + int linear_prefix_cache_interval_blocks, int rank, int attn_cp_size, core::Allocator allocator, GetFreeMemSize get_free_size): - block_seq_len_(cache_block_seq_len), rank_(rank), attn_cp_size_(attn_cp_size) + block_seq_len_(cache_block_seq_len), + rank_(rank), + attn_cp_size_(attn_cp_size), + linear_prefix_cache_interval_blocks_(linear_prefix_cache_interval_blocks), + linear_prefix_cache_interval_tokens_(cache_block_seq_len * linear_prefix_cache_interval_blocks) { TM_CHECK_GT(attn_tp_size, 0); TM_CHECK_GT(cache_block_seq_len, 0); + TM_CHECK_GE(linear_prefix_cache_interval_blocks, 1); int cache_layer_num = model_param.layer_num; int num_linear_layers = 0; @@ -57,7 +128,12 @@ SequenceManager::SequenceManager(const ModelParam& model_param, } } - const size_t free_before = (block_count < 1. && num_linear_layers > 0) ? get_free_size() : 0; + const size_t free_before = (block_count < 1. && num_linear_layers > 0) ? get_free_size() : 0; + size_t live_linear_bytes{}; + std::vector linear_prefix_conv_state_shape; + std::vector linear_prefix_recurrent_state_shape; + DataType linear_prefix_conv_state_dtype{}; + DataType linear_prefix_recurrent_state_dtype{}; if (num_linear_layers > 0) { @@ -77,6 +153,10 @@ SequenceManager::SequenceManager(const ModelParam& model_param, pooled_recurrent_states_ = {{max_batch_size, num_linear_layers, num_v_heads, key_head_dim, value_head_dim}, model_param.linear_state_dtype, kDEVICE}; + linear_prefix_conv_state_shape = {num_linear_layers, d_conv, conv_dim}; + linear_prefix_recurrent_state_shape = {num_linear_layers, num_v_heads, key_head_dim, value_head_dim}; + linear_prefix_conv_state_dtype = model_param.data_type; + linear_prefix_recurrent_state_dtype = model_param.linear_state_dtype; free_linear_state_slots_.reserve(max_batch_size); for (int slot = max_batch_size - 1; slot >= 0; --slot) { @@ -92,6 +172,7 @@ SequenceManager::SequenceManager(const ModelParam& model_param, (conv_one.byte_size() + recurrent_one.byte_size()) * mb); TM_LOG_INFO("[SeqMgr] linear-state combined total: {:.2f} MB", (pooled_conv_states_.byte_size() + pooled_recurrent_states_.byte_size()) * mb); + live_linear_bytes = pooled_conv_states_.byte_size() + pooled_recurrent_states_.byte_size(); } const int dbits = byte_size(runtime_dtype, 8); @@ -110,31 +191,68 @@ SequenceManager::SequenceManager(const ModelParam& model_param, block::Layout layout{block_config}; // dump(layout); - size_t block_size = layout.block_size(cache_layer_num); + size_t block_size = layout.block_size(cache_layer_num); + const bool has_linear_prefix_checkpoints = enable_prefix_caching && num_linear_layers > 0; + const size_t linear_prefix_slot_bytes = + has_linear_prefix_checkpoints ? + hybrid_prefix_budget::GetHybridCheckpointSlotBytes(linear_prefix_conv_state_shape, + linear_prefix_conv_state_dtype, + linear_prefix_recurrent_state_shape, + linear_prefix_recurrent_state_dtype) : + 0; if (num_linear_layers > 0 && block_count < 1.) { - const size_t linear_bytes = pooled_conv_states_.byte_size() + pooled_recurrent_states_.byte_size(); const size_t target_bytes = static_cast(free_before * block_count); TM_LOG_INFO("[SeqMgr] Adjusting block_count: free_before {:.2f} MB, linear {:.2f} MB, target {:.2f} MB", free_before / (1024. * 1024.), - linear_bytes / (1024. * 1024.), + live_linear_bytes / (1024. * 1024.), target_bytes / (1024. * 1024.)); - if (target_bytes <= linear_bytes) { + if (target_bytes <= live_linear_bytes) { TM_LOG_ERROR("[SeqMgr] Linear-state memory ({:.2f} MB) >= cache budget ({:.2f} MB). ", - linear_bytes / (1024. * 1024.), + live_linear_bytes / (1024. * 1024.), target_bytes / (1024. * 1024.)); TM_CHECK(0) << "Please decrease max_batch_size to reduce total linear state size or increase cache_max_entry_count."; } - const size_t cache_bytes = target_bytes - linear_bytes; - block_count = static_cast(cache_bytes) / static_cast(block_size); + const size_t cache_bytes = target_bytes - live_linear_bytes; + block_count = + has_linear_prefix_checkpoints ? static_cast(hybrid_prefix_budget::SolveHybridKvBlockCount( + cache_bytes, block_size, linear_prefix_slot_bytes, linear_prefix_cache_interval_blocks_)) : + static_cast(cache_bytes) / static_cast(block_size); TM_LOG_INFO("[SeqMgr] Adjusted block_count to {:.0f}", block_count); } + else if (num_linear_layers > 0 && block_count >= 1.) { + const size_t requested_blocks = static_cast(block_count); + const auto requested_cache_bytes = + has_linear_prefix_checkpoints ? hybrid_prefix_budget::GetHybridCacheBytesForBlocks( + requested_blocks, block_size, linear_prefix_slot_bytes, linear_prefix_cache_interval_blocks_) : + static_cast(requested_blocks) * block_size; + const size_t available_after_live = get_free_size(); + TM_CHECK(requested_cache_bytes <= static_cast(available_after_live)) + << "Insufficient memory for " + << (has_linear_prefix_checkpoints ? "hybrid prefix cache blocks and checkpoints." : "KV cache blocks."); + } block_manager_ = std::make_shared(block_size, block_count, chunk_size, allocator, get_free_size); if (enable_prefix_caching) { - block_trie_ = std::make_shared(block_config.block_len_, block_manager_); + int trie_linear_interval_blocks = 0; + int linear_state_slot_capacity = 0; + if (num_linear_layers > 0) { + const int checkpoint_slots = block_manager_->max_block_count() / linear_prefix_cache_interval_blocks_; + if (checkpoint_slots > 0) { + trie_linear_interval_blocks = linear_prefix_cache_interval_blocks_; + linear_state_slot_capacity = checkpoint_slots; + } + } + block_trie_ = std::make_shared(block_config.block_len_, + block_manager_, + trie_linear_interval_blocks, + linear_state_slot_capacity, + linear_prefix_conv_state_shape, + linear_prefix_conv_state_dtype, + linear_prefix_recurrent_state_shape, + linear_prefix_recurrent_state_dtype); } TM_LOG_WARN("prefix caching is {}", enable_prefix_caching ? "enabled" : "disabled"); } @@ -236,9 +354,13 @@ void SequenceManager::ReleaseLinearStateSlot(const Sequence& sequence) free_linear_state_slots_.push_back(slot_it->second); seq_to_linear_state_slot_.erase(slot_it); } - seq.conv_states = {}; - seq.recurrent_states = {}; - seq.linear_states_need_reset = false; + seq.conv_states = {}; + seq.recurrent_states = {}; + seq.linear_states_need_reset = false; + seq.pending_linear_prefix_conv_states = {}; + seq.pending_linear_prefix_recurrent_states = {}; + seq.pending_linear_prefix_capture_count = 0; + seq.pending_linear_prefix_capture_base_len = 0; } void SequenceManager::InvalidateStatesAndCache(const Sequence& sequence) @@ -256,8 +378,12 @@ void SequenceManager::InvalidateStatesAndCache(const Sequence& sequence, BlockId seq.blocks.clear(); seq.block_unique_ids.clear(); - seq.input_length = 0; - seq.cache_len = 0; + seq.input_length = 0; + seq.cache_len = 0; + seq.pending_linear_prefix_conv_states = {}; + seq.pending_linear_prefix_recurrent_states = {}; + seq.pending_linear_prefix_capture_count = 0; + seq.pending_linear_prefix_capture_base_len = 0; ReleaseLinearStateSlot(seq); } @@ -268,7 +394,8 @@ void SequenceManager::CachePrompt(const Sequences& sequences, int active_size) } for (int i = 0; i < active_size; ++i) { - if (auto& seq = *sequences[i]; !seq.prompt.empty()) { + auto& seq = *sequences[i]; + if (!seq.prompt.empty()) { const auto& [block_ids, unique_ids] = block_trie_->Cache(seq, seq.prompt); if (rank_ == 0) { // clang-format off @@ -282,6 +409,10 @@ void SequenceManager::CachePrompt(const Sequences& sequences, int active_size) seq.prompt.clear(); } } + seq.pending_linear_prefix_conv_states = {}; + seq.pending_linear_prefix_recurrent_states = {}; + seq.pending_linear_prefix_capture_count = 0; + seq.pending_linear_prefix_capture_base_len = 0; } } @@ -335,12 +466,14 @@ void SequenceManager::VerifyAndLockCached(const Sequences& sequences) } if (!freed_blocks.empty()) { block_manager_->Free(freed_blocks); + VerifyPrefixCache(); } block_manager_->Lock(valid_blocks); } void SequenceManager::CommitUnlockAndFree() { + bool prefix_cache_changed = false; if (!unlocked_.empty()) { block_manager_->Unlock(unlocked_); unlocked_.clear(); @@ -349,6 +482,18 @@ void SequenceManager::CommitUnlockAndFree() if (!freed_.empty()) { block_manager_->Free(freed_); freed_.clear(); + prefix_cache_changed = true; + } + + if (prefix_cache_changed) { + VerifyPrefixCache(); + } +} + +void SequenceManager::VerifyPrefixCache() +{ + if (block_trie_) { + block_trie_->Verify(); } } @@ -591,18 +736,40 @@ void SequenceManager::PrefixMatch(Sequences& sequences, const std::vector& continue; } - const auto& [block_ids, unique_ids] = block_trie_->Match(seq); + auto match = block_trie_->Match(seq); + auto block_ids = std::move(match.block_ids); + auto unique_ids = std::move(match.unique_ids); + if (pooled_recurrent_states_) { + const int matched_blocks = match.linear_cache_len / block_seq_len_; + block_ids.resize(matched_blocks); + unique_ids.resize(matched_blocks); + } if (rank_ == 0) { // clang-format off - TM_LOG_INFO("ID {}, hit blocks {}, cache_len {}", seq.id, (int)block_ids.size(), seq.cache_len); - TM_LOG_DEBUG("ID {}, hit block_ids {}, unique_ids {}", seq.id, - vector2string(block_ids), vector2string(unique_ids)); + TM_LOG_INFO("[SeqMgr][match] ID {}, hit blocks {}, linear_cache_len {}, cache_len {}", + seq.id, + (int)block_ids.size(), + match.linear_cache_len, + seq.cache_len); + TM_LOG_DEBUG("[SeqMgr][match] ID {}, hit block_ids {}, unique_ids {}", + seq.id, + vector2string(block_ids), + vector2string(unique_ids)); // clang-format on } /// TODO: `Unlock` and `Lock` can't be batched because there may be repeated blocks between sequences - if (const int offset = seq.cache_len / block_seq_len_; offset < block_ids.size()) { + const int desired_blocks = block_ids.size(); + if ((int)seq.blocks.size() > desired_blocks) { + if (BlockIds tail{seq.blocks.begin() + desired_blocks, seq.blocks.end()}; !tail.empty()) { + block_manager_->Unlock(tail); + seq.blocks.resize(desired_blocks); + seq.block_unique_ids.resize(desired_blocks); + } + } + + if (const int offset = seq.cache_len / block_seq_len_; offset < desired_blocks) { if (BlockIds tail{seq.blocks.begin() + offset, seq.blocks.end()}; !tail.empty()) { block_manager_->Unlock(tail); seq.blocks.resize(offset); @@ -610,9 +777,17 @@ void SequenceManager::PrefixMatch(Sequences& sequences, const std::vector& } seq.blocks.insert(seq.blocks.end(), block_ids.begin() + offset, block_ids.end()); seq.block_unique_ids.insert(seq.block_unique_ids.end(), unique_ids.begin() + offset, unique_ids.end()); - seq.cache_len = seq.blocks.size() * block_seq_len_; block_manager_->Lock({block_ids.begin() + offset, block_ids.end()}); } + seq.cache_len = desired_blocks * block_seq_len_; + + if (pooled_recurrent_states_ && match.linear_cache_len > 0) { + TM_CHECK(seq.conv_states && seq.recurrent_states) + << "Linear-attention state slot is not bound for sequence " << seq.id; + Copy(match.conv_states, seq.conv_states); + Copy(match.recurrent_states, seq.recurrent_states); + seq.linear_states_need_reset = false; + } if (rank_ == 0) { // clang-format off @@ -704,6 +879,7 @@ auto SequenceManager::Materialize(Sequences sequences, // evict cached blocks -> free if (schedule.evict) { block_manager_->Evict(schedule.evict); + VerifyPrefixCache(); } // allocate & assign blocks @@ -727,10 +903,6 @@ auto SequenceManager::Materialize(Sequences sequences, // block_manager_->active_count(), // block_manager_->cached_count(), // block_manager_->free_count()); - if (block_trie_) { - block_trie_->Verify(); - } - return outcome; } diff --git a/src/turbomind/models/llama/SequenceManager.h b/src/turbomind/models/llama/SequenceManager.h index fff2706379..bef91cadac 100644 --- a/src/turbomind/models/llama/SequenceManager.h +++ b/src/turbomind/models/llama/SequenceManager.h @@ -55,6 +55,10 @@ struct Sequence { mutable Tensor conv_states; mutable Tensor recurrent_states; mutable bool linear_states_need_reset = false; + mutable Tensor pending_linear_prefix_conv_states; + mutable Tensor pending_linear_prefix_recurrent_states; + mutable int pending_linear_prefix_capture_count = 0; + mutable int pending_linear_prefix_capture_base_len = 0; explicit Sequence(uint64_t _id): id(_id) {} @@ -98,6 +102,7 @@ class SequenceManager { double block_count, int chunk_size, bool enable_prefix_caching, + int linear_prefix_cache_interval_blocks, int rank, int attn_cp_size, core::Allocator allocator, @@ -199,6 +204,7 @@ class SequenceManager { void Erase(std::map::iterator& it); void CommitUnlockAndFree(); + void VerifyPrefixCache(); void InvalidateStatesAndCache(const Sequence& seq, BlockIds& freed_blocks); @@ -218,6 +224,8 @@ class SequenceManager { int block_seq_len_; int rank_; int attn_cp_size_; + int linear_prefix_cache_interval_blocks_{}; + int linear_prefix_cache_interval_tokens_{}; // Use `std::map` to avoid reference invalidation std::map sequences_; diff --git a/src/turbomind/models/llama/gated_delta_net_kernels.cu b/src/turbomind/models/llama/gated_delta_net_kernels.cu index 8d6ff4f4d3..48f1bcbf08 100644 --- a/src/turbomind/models/llama/gated_delta_net_kernels.cu +++ b/src/turbomind/models/llama/gated_delta_net_kernels.cu @@ -561,10 +561,14 @@ __global__ void chunked_gated_delta_rule_kernel(T* v_out, const T* g_in, S* const* state_ptrs, const int* q_offsets, + const int* k_offsets, int num_v_heads, int num_k_heads, int k_dim_total, - int state_layer_offset) + int state_layer_offset, + S* const* capture_state_ptrs, + int capture_stride, + int checkpoint_interval_tokens) { constexpr int C = kChunkSize; constexpr int D = kHeadDim; @@ -588,13 +592,14 @@ __global__ void chunked_gated_delta_rule_kernel(T* v_out, const float scale = rsqrtf((float)D); // ── State tiling (same as v2) ── - constexpr int tile_k = 8; - constexpr int tile_v = 8; - constexpr int k_tiles = D / tile_k; // 16 - constexpr int k_threads = k_tiles; // 16 - constexpr int v_threads = kBlockDim / k_threads; // 16 - constexpr int v_tiles = D / tile_v; // 16 - constexpr int v_iters = cdiv(v_tiles, v_threads); // 1 + constexpr int tile_k = 8; + constexpr int tile_v = 8; + constexpr int k_tiles = D / tile_k; // 16 + constexpr int k_threads = k_tiles; // 16 + constexpr int v_threads = kBlockDim / k_threads; // 16 + constexpr int v_tiles = D / tile_v; // 16 + constexpr int v_iters = cdiv(v_tiles, v_threads); // 1 + constexpr int kAccessCState = sizeof(uint4) / sizeof(S); const int offset_k = threadIdx.x % k_threads; const int offset_v = threadIdx.x / k_threads; @@ -671,7 +676,9 @@ __global__ void chunked_gated_delta_rule_kernel(T* v_out, const int load_tok = threadIdx.x / kThreadsPerTok; // which token (0..C-1) const int load_lane = threadIdx.x % kThreadsPerTok; // lane within token's warp - const int num_chunks = (seq_len + C - 1) / C; + const int num_chunks = (seq_len + C - 1) / C; + const int history_len = k_offsets ? ((__ldg(&k_offsets[b + 1]) - __ldg(&k_offsets[b])) - seq_len) : 0; + S* const capture_base = capture_state_ptrs ? capture_state_ptrs[b] : nullptr; for (int ci = 0; ci < num_chunks; ++ci) { const int chunk_start = tok_off + ci * C; @@ -795,6 +802,28 @@ __global__ void chunked_gated_delta_rule_kernel(T* v_out, if (offset_k == 0) Store(&v_out[gt * v_dim + h * D + v_base], vec_O); } + + if (capture_base && checkpoint_interval_tokens > 0) { + const int processed = history_len + ci * C + t + 1; + if (processed % checkpoint_interval_tokens == 0) { + const int capture_idx = + processed / checkpoint_interval_tokens - history_len / checkpoint_interval_tokens - 1; + S* capture_ptr = capture_base + capture_idx * capture_stride + state_layer_offset + h * state_size; + PRAGMA_UNROLL + for (int vi = 0; vi < v_iters; ++vi) { + PRAGMA_UNROLL + for (int k = 0; k < tile_k; ++k) { + PRAGMA_UNROLL + for (int c = 0; c < tile_v / kAccessCState; ++c) { + auto tmp = cast((Array&)vec_S[vi][k][c * kAccessCState]); + Store(capture_ptr + (offset_k * tile_k + k) * D + (offset_v + vi * v_threads) * tile_v + + c * kAccessCState, + tmp); + } + } + } + } + } } __syncthreads(); // [sync 2] ensure all reads done before next chunk overwrites smem } // chunk loop @@ -852,9 +881,46 @@ void invokeChunkedGatedDeltaRuleBatched(Ref v_out_, int num_k_heads, int state_layer_offset, DataType state_dtype, + int sm_count, + int* work_counter, + cudaStream_t stream) +{ + invokeChunkedGatedDeltaRuleBatched(v_out_, + qkv_in, + beta, + g, + state_ptrs, + q_offsets, + Buffer_{}, + batch_size, + num_k_heads, + state_layer_offset, + state_dtype, + sm_count, + work_counter, + stream, + Buffer_{}, + 0, + 0); +} + +void invokeChunkedGatedDeltaRuleBatched(Ref v_out_, + const Tensor& qkv_in, + const Tensor& beta, + const Tensor& g, + const Buffer_& state_ptrs, + const Buffer_& q_offsets, + const Buffer_& k_offsets, + int batch_size, + int num_k_heads, + int state_layer_offset, + DataType state_dtype, int /*sm_count*/, int* /*work_counter*/, - cudaStream_t stream) + cudaStream_t stream, + const Buffer_& capture_state_ptrs, + int capture_stride, + int checkpoint_interval_tokens) { auto& v_out = v_out_.get(); @@ -895,16 +961,21 @@ void invokeChunkedGatedDeltaRuleBatched(Ref v_out_, cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_sz); } - kernel<<>>(v_out.data(), - qkv_in.data(), - beta.data(), - g.data(), - (S* const*)state_ptrs.data(), - q_offsets.data(), - num_v_heads, - num_k_heads, - k_dim_total, - state_layer_offset); + kernel<<>>( + v_out.data(), + qkv_in.data(), + beta.data(), + g.data(), + (S* const*)state_ptrs.data(), + q_offsets.data(), + k_offsets.data_or((const int*)nullptr), + num_v_heads, + num_k_heads, + k_dim_total, + state_layer_offset, + reinterpret_cast(capture_state_ptrs.data_or((void* const*)nullptr)), + capture_stride, + checkpoint_interval_tokens); }; if (state_dtype == kFloat32) { launch(float{}); @@ -1077,7 +1148,10 @@ __global__ void __launch_bounds__(BLOCK_DIM) fused_conv1d_batched_kernel_v2(T* int num_token_tiles, int state_layer_offset, int total_work, - int num_ch_tiles) + int num_ch_tiles, + void* const* capture_state_ptrs, + int capture_stride, + int checkpoint_interval_tokens) { static_assert(BLOCK_DIM * CHANNELS_PER_THREAD > 0); @@ -1171,8 +1245,9 @@ __global__ void __launch_bounds__(BLOCK_DIM) fused_conv1d_batched_kernel_v2(T* n_tokens = min(NUM_TOKENS, seq_len - t_local_start); } - const int ring_start = (history_len + t_local_start + 1) % D_CONV; - T* state_base = (T*)conv_state_ptrs[b] + state_layer_offset; + const int ring_start = (history_len + t_local_start + 1) % D_CONV; + T* state_base = (T*)conv_state_ptrs[b] + state_layer_offset; + T* capture_base = capture_state_ptrs ? (T*)capture_state_ptrs[b] : nullptr; if (ch_active) { constexpr int VALS_SIZE = NUM_TOKENS + D_CONV - 1; @@ -1214,6 +1289,24 @@ __global__ void __launch_bounds__(BLOCK_DIM) fused_conv1d_batched_kernel_v2(T* } Store(out + (seq_off + t_local_start + tok) * conv_dim + c_base, out_vals); + + if (capture_base && checkpoint_interval_tokens > 0) { + const int processed = history_len + t_local_start + tok + 1; + if (processed % checkpoint_interval_tokens == 0) { + const int capture_idx = + processed / checkpoint_interval_tokens - history_len / checkpoint_interval_tokens - 1; + T* capture_ptr = capture_base + capture_idx * capture_stride + state_layer_offset; + PRAGMA_UNROLL + for (int d = 0; d < D_CONV; ++d) { + const int pos = t_local_start - (D_CONV - 1) + tok + d; + int ring_d = (history_len + pos + 1) % D_CONV; + if (ring_d < 0) { + ring_d += D_CONV; + } + Store(capture_ptr + ring_d * conv_dim + c_base, vals[tok + d]); + } + } + } } } @@ -1243,6 +1336,39 @@ void invokeFusedConv1dSiLU(Ref out_, int sm_count, int* work_counter, cudaStream_t stream) +{ + invokeFusedConv1dSiLU(out_, + in, + weight, + bias, + conv_state_ptrs, + q_offsets, + k_offsets, + batch_size, + state_layer_offset, + sm_count, + work_counter, + stream, + Buffer_{}, + 0, + 0); +} + +void invokeFusedConv1dSiLU(Ref out_, + const Tensor& in, + const Tensor& weight, + const Tensor& bias, + const Buffer_& conv_state_ptrs, + const Buffer_& q_offsets, + const Buffer_& k_offsets, + int batch_size, + int state_layer_offset, + int sm_count, + int* work_counter, + cudaStream_t stream, + const Buffer_& capture_state_ptrs, + int capture_stride, + int checkpoint_interval_tokens) { auto& out = out_.get(); @@ -1287,7 +1413,10 @@ void invokeFusedConv1dSiLU(Ref out_, num_token_tiles, state_layer_offset, total_work, - num_ch_tiles); + num_ch_tiles, + capture_state_ptrs.data_or((void* const*)nullptr), + capture_stride, + checkpoint_interval_tokens); }; int avg_seq = total_tokens / batch_size; diff --git a/src/turbomind/models/llama/gated_delta_net_kernels.h b/src/turbomind/models/llama/gated_delta_net_kernels.h index 9519db4da7..cf8412a85b 100644 --- a/src/turbomind/models/llama/gated_delta_net_kernels.h +++ b/src/turbomind/models/llama/gated_delta_net_kernels.h @@ -33,6 +33,22 @@ void invokeFusedConv1dSiLU(Ref out, int* work_counter, cudaStream_t stream); +void invokeFusedConv1dSiLU(Ref out, + const Tensor& in, + const Tensor& weight, + const Tensor& bias, + const Buffer_& conv_state_ptrs, + const Buffer_& q_offsets, + const Buffer_& k_offsets, + int batch_size, + int state_layer_offset, + int sm_count, + int* work_counter, + cudaStream_t stream, + const Buffer_& capture_state_ptrs, + int capture_stride, + int checkpoint_interval_tokens); + // All three recurrent-rule launchers share the same trailing parameters for // interface consistency: // sm_count — multiprocessor count, queried once by the caller at init @@ -96,6 +112,24 @@ void invokeChunkedGatedDeltaRuleBatched(Ref v_out, int* work_counter, cudaStream_t stream); +void invokeChunkedGatedDeltaRuleBatched(Ref v_out, + const Tensor& qkv_in, + const Tensor& beta, + const Tensor& g, + const Buffer_& state_ptrs, + const Buffer_& q_offsets, + const Buffer_& k_offsets, + int batch_size, + int num_k_heads, + int state_layer_offset, + DataType state_dtype, + int sm_count, + int* work_counter, + cudaStream_t stream, + const Buffer_& capture_state_ptrs, + int capture_stride, + int checkpoint_interval_tokens); + // ============================================================================= // Helper kernels // ============================================================================= diff --git a/src/turbomind/models/llama/llama_params.h b/src/turbomind/models/llama/llama_params.h index b61e5b1fe8..cb9d125da8 100644 --- a/src/turbomind/models/llama/llama_params.h +++ b/src/turbomind/models/llama/llama_params.h @@ -133,6 +133,7 @@ struct EngineParam { float cache_max_block_count; int cache_chunk_size; bool enable_prefix_caching; + int linear_prefix_cache_interval_blocks; bool enable_metrics; // chunking params diff --git a/src/turbomind/turbomind.cc b/src/turbomind/turbomind.cc index 1529269fde..ce5c171dd7 100644 --- a/src/turbomind/turbomind.cc +++ b/src/turbomind/turbomind.cc @@ -439,14 +439,12 @@ TurboMind::Impl::Impl(string model_dir, string config, FFICtxFactory ffi_ctx_fac engine_param_.max_context_token_num = engine["max_context_token_num"].as(0); engine_param_.session_len = model["session_len"].as(0); - engine_param_.cache_max_block_count = engine["cache_max_entry_count"].as(0); - engine_param_.cache_chunk_size = engine["cache_chunk_size"].as(0); - engine_param_.enable_prefix_caching = engine["enable_prefix_caching"].as(false); - engine_param_.enable_metrics = engine["enable_metrics"].as(false); - - if (engine_param_.enable_prefix_caching && HasLinearAttention(model_param_)) { - TM_CHECK(0) << "Prefix caching is unsupported when linear attention is present"; - } + engine_param_.cache_max_block_count = engine["cache_max_entry_count"].as(0); + engine_param_.cache_chunk_size = engine["cache_chunk_size"].as(0); + engine_param_.enable_prefix_caching = engine["enable_prefix_caching"].as(false); + engine_param_.linear_prefix_cache_interval_blocks = engine["linear_prefix_cache_interval_blocks"].as(64); + engine_param_.enable_metrics = engine["enable_metrics"].as(false); + TM_CHECK_GE(engine_param_.linear_prefix_cache_interval_blocks, 1); engine_param_.num_tokens_per_iter = engine["num_tokens_per_iter"].as(0); engine_param_.max_prefill_iters = engine["max_prefill_iters"].as(1); diff --git a/tests/test_lmdeploy/test_turbomind/test_api_server.py b/tests/test_lmdeploy/test_turbomind/test_api_server.py new file mode 100644 index 0000000000..2c80f93d19 --- /dev/null +++ b/tests/test_lmdeploy/test_turbomind/test_api_server.py @@ -0,0 +1,113 @@ +import sys +from types import ModuleType, SimpleNamespace + +from lmdeploy.cli import serve as serve_module +from lmdeploy.messages import TurbomindEngineConfig + + +def _make_api_server_args(**overrides): + args = dict(model_path='QuantTrio/Qwen3.5-27B-AWQ', + model_name='qwen35-awq', + backend='turbomind', + dtype='auto', + tp=2, + dp=1, + ep=1, + cp=1, + nnodes=1, + node_rank=0, + dist_init_addr=None, + max_batch_size=None, + session_len=4096, + model_format='awq', + quant_policy=8, + rope_scaling_factor=0.0, + cache_max_entry_count=0.8, + cache_block_seq_len=64, + enable_prefix_caching=True, + linear_prefix_cache_interval_blocks=4, + max_prefill_token_num=8192, + num_tokens_per_iter=0, + max_prefill_iters=1, + async_=1, + communicator='nccl', + disable_metrics=False, + adapters=None, + device='cuda', + eager_mode=False, + disable_vision_encoder=False, + logprobs_mode='raw_logits', + dllm_block_length=64, + dllm_unmasking_strategy='low_confidence_dynamic', + dllm_denoising_steps=0, + dllm_confidence_threshold=0.0, + enable_return_routed_experts=False, + distributed_executor_backend=None, + chat_template=None, + vision_max_batch_size=1, + server_name='127.0.0.1', + server_port=23333, + allow_origins=['*'], + allow_credentials=False, + allow_methods=['*'], + allow_headers=['*'], + allow_terminate_by_client=False, + enable_abort_handling=False, + log_level='info', + api_keys=None, + ssl=None, + proxy_url=None, + max_log_len=None, + disable_fastapi_docs=False, + max_concurrent_requests=None, + reasoning_parser='qwen-qwq', + tool_call_parser='qwen3coder', + hf_overrides=None) + args.update(overrides) + return SimpleNamespace(**args) + + +def test_api_server_turbomind_forwards_hybrid_prefix_cache_options(monkeypatch): + captured = {} + fake_api_server = ModuleType('lmdeploy.serve.openai.api_server') + + def fake_serve(model_path, **kwargs): + captured['model_path'] = model_path + captured.update(kwargs) + + fake_api_server.serve = fake_serve + + monkeypatch.setitem(sys.modules, 'lmdeploy.serve.openai.api_server', fake_api_server) + monkeypatch.setattr('lmdeploy.archs.autoget_backend', lambda _: 'turbomind') + monkeypatch.setattr(serve_module, 'get_max_batch_size', lambda device: 13) + monkeypatch.setattr(serve_module, 'get_chat_template', lambda *_: None) + monkeypatch.setattr(serve_module, 'get_speculative_config', lambda _: None) + + serve_module.SubCliServe.api_server(_make_api_server_args()) + + assert captured['backend'] == 'turbomind' + assert captured['model_path'] == 'QuantTrio/Qwen3.5-27B-AWQ' + assert isinstance(captured['backend_config'], TurbomindEngineConfig) + assert captured['backend_config'].enable_prefix_caching is True + assert captured['backend_config'].linear_prefix_cache_interval_blocks == 4 + + +def test_api_server_turbomind_uses_default_cuda_batch_size(monkeypatch): + captured = {} + fake_api_server = ModuleType('lmdeploy.serve.openai.api_server') + + def fake_serve(model_path, **kwargs): + captured['model_path'] = model_path + captured.update(kwargs) + + fake_api_server.serve = fake_serve + + monkeypatch.setitem(sys.modules, 'lmdeploy.serve.openai.api_server', fake_api_server) + monkeypatch.setattr('lmdeploy.archs.autoget_backend', lambda _: 'turbomind') + monkeypatch.setattr(serve_module, 'get_max_batch_size', lambda device: 7) + monkeypatch.setattr(serve_module, 'get_chat_template', lambda *_: None) + monkeypatch.setattr(serve_module, 'get_speculative_config', lambda _: None) + + serve_module.SubCliServe.api_server(_make_api_server_args(max_batch_size=None)) + + assert captured['backend_config'].max_batch_size == 7 diff --git a/tests/test_lmdeploy/test_turbomind/test_engine_config.py b/tests/test_lmdeploy/test_turbomind/test_engine_config.py new file mode 100644 index 0000000000..8747009bb4 --- /dev/null +++ b/tests/test_lmdeploy/test_turbomind/test_engine_config.py @@ -0,0 +1,19 @@ +import pytest +from pydantic import ValidationError + +from lmdeploy import TurbomindEngineConfig + + +def test_linear_prefix_cache_interval_blocks_default(): + config = TurbomindEngineConfig(enable_prefix_caching=True) + assert config.linear_prefix_cache_interval_blocks == 64 + + +def test_linear_prefix_cache_interval_blocks_validation(): + with pytest.raises(ValidationError, match='invalid linear_prefix_cache_interval_blocks'): + TurbomindEngineConfig(linear_prefix_cache_interval_blocks=0) + + +def test_linear_prefix_cache_interval_blocks_override(): + config = TurbomindEngineConfig(enable_prefix_caching=True, linear_prefix_cache_interval_blocks=4) + assert config.linear_prefix_cache_interval_blocks == 4