Commit 2d8338f

[Optimization][DeepSeekV3.2] Reducing slot_mapping compute frequency from twice per layer to a single pre-processing step. (#7367)
1 parent d2d633b commit 2d8338f

10 files changed

Lines changed: 73 additions & 146 deletions

custom_ops/gpu_ops/cpp_extensions.cc

Lines changed: 4 additions & 6 deletions

@@ -540,12 +540,10 @@ std::vector<paddle::Tensor> count_tokens_per_expert_func(
     const paddle::Tensor& topk_ids,
     int64_t num_experts,
     bool compute_padded_cumsum = false);
-void GetPositionIdsAndMaskEncoderBatch(
-    const paddle::Tensor& seq_lens_encoder,
-    const paddle::Tensor& seq_lens_decoder,
-    const paddle::Tensor& seq_lens_this_time,
-    const paddle::Tensor& position_ids,
-    const paddle::Tensor& mask_encoder_batch);
+void GetPositionIdsAndMaskEncoderBatch(const paddle::Tensor& seq_lens_encoder,
+                                       const paddle::Tensor& seq_lens_decoder,
+                                       const paddle::Tensor& seq_lens_this_time,
+                                       const paddle::Tensor& position_ids);
 
 std::vector<paddle::Tensor> DecodeMLAWriteCacheKernel(
     const paddle::Tensor& kv_nope,

custom_ops/gpu_ops/get_position_ids_and_mask_encoder_batch.cu

Lines changed: 13 additions & 19 deletions

@@ -20,8 +20,7 @@ __global__ void GetPositionIdsAndMaskEncoderBatchKernel(
     const int* seq_lens_decoder,  // [bsz] decoder length of each batch
     const int* seq_lens_this_time,
     int* position_ids,  // flattened 1-D position_ids output
-    int* mask_encoder_batch,
-    const int bsz) {  // batch size
+    const int bsz) {  // batch size
   // current thread index (one thread per batch)
   int tid = threadIdx.x;
   if (tid >= bsz) return;
@@ -43,43 +42,38 @@ __global__ void GetPositionIdsAndMaskEncoderBatchKernel(
   // write the encoder position_ids
   for (int i = 0; i < encoder_len; i++) {
     position_ids[offset + i] = i;
-    mask_encoder_batch[offset + i] = 1;
   }
   offset += encoder_len;
 
   // write the decoder position_ids
   if (decoder_len > 0) {
     for (int i = 0; i < seq_len_this_time; i++) {
       position_ids[offset + i] = decoder_len + i;  // continue from the decoder length itself
-      mask_encoder_batch[offset + i] = 0;
     }
   }
 }
 
-void GetPositionIdsAndMaskEncoderBatch(
-    const paddle::Tensor& seq_lens_encoder,
-    const paddle::Tensor& seq_lens_decoder,
-    const paddle::Tensor& seq_lens_this_time,
-    const paddle::Tensor& position_ids,
-    const paddle::Tensor& mask_encoder_batch) {
+void GetPositionIdsAndMaskEncoderBatch(const paddle::Tensor& seq_lens_encoder,
+                                       const paddle::Tensor& seq_lens_decoder,
+                                       const paddle::Tensor& seq_lens_this_time,
+                                       const paddle::Tensor& position_ids) {
   const int bsz = seq_lens_this_time.shape()[0];
 
   GetPositionIdsAndMaskEncoderBatchKernel<<<1, bsz, 0, position_ids.stream()>>>(
       seq_lens_encoder.data<int>(),
      seq_lens_decoder.data<int>(),
      seq_lens_this_time.data<int>(),
      const_cast<int*>(position_ids.data<int>()),
-     const_cast<int*>(mask_encoder_batch.data<int>()),
      bsz);
 }
 
 PD_BUILD_STATIC_OP(get_position_ids_and_mask_encoder_batch)
-    .Inputs({"seq_lens_encoder",
-             "seq_lens_decoder",
-             "seq_lens_this_time",
-             "position_ids",
-             "mask_encoder_batch"})
-    .Outputs({"position_ids_out", "mask_encoder_batch_out"})
-    .SetInplaceMap({{"position_ids", "position_ids_out"},
-                    {"mask_encoder_batch", "mask_encoder_batch_out"}})
+    .Inputs({
+        "seq_lens_encoder",
+        "seq_lens_decoder",
+        "seq_lens_this_time",
+        "position_ids",
+    })
+    .Outputs({"position_ids_out"})
+    .SetInplaceMap({{"position_ids", "position_ids_out"}})
     .SetKernelFn(PD_KERNEL(GetPositionIdsAndMaskEncoderBatch));
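For intuition, the layout the trimmed kernel writes can be reproduced with a short pure-Python sketch. This is a reference model of the CUDA loops above rather than FastDeploy code, and it assumes the per-request offsets follow the prefix sums the kernel derives from the sequence lengths:

def position_ids_reference(seq_lens_encoder, seq_lens_decoder, seq_lens_this_time):
    """Concatenate per-request position ids into one flat, padding-free array."""
    out = []
    for enc_len, dec_len, cur_len in zip(seq_lens_encoder, seq_lens_decoder, seq_lens_this_time):
        # prefill (encoder) tokens count positions from 0
        out.extend(range(enc_len))
        # decode tokens continue from the number of tokens already cached
        if dec_len > 0:
            out.extend(dec_len + i for i in range(cur_len))
    return out

# one prefill request of 4 tokens plus one decode request at step 7:
assert position_ids_reference([4, 0], [0, 7], [4, 1]) == [0, 1, 2, 3, 7]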

fastdeploy/model_executor/forward_meta.py

Lines changed: 2 additions & 1 deletion

@@ -160,7 +160,8 @@ class ForwardMeta:
 
     # for mla & dsa
     position_ids: Optional[paddle.Tensor] = None
-    mask_encoder_batch: Optional[paddle.Tensor] = None
+    # for kvcache slot
+    slot_mapping: Optional[paddle.Tensor] = None
 
     real_bsz: int = 0

fastdeploy/model_executor/layers/attention/dsa_attention_backend.py

Lines changed: 1 addition & 35 deletions

@@ -54,33 +54,6 @@ def yarn_get_mscale(scale=1, mscale=1):
     return 0.1 * mscale * math.log(scale) + 1.0
 
 
-def compute_slot_mapping(
-    block_tables: paddle.Tensor,  # [num_reqs, max_blocks_per_req]
-    positions: paddle.Tensor,  # [num_tokens] position of each token
-    batch_id_per_token: paddle.Tensor,  # [num_tokens] which request each token belongs to
-    block_size: int,
-) -> paddle.Tensor:
-    """
-    Compute slot_mapping.
-
-    Formula: slot = block_id * block_size + offset_in_block
-    """
-    # 1. Compute each token's block index
-    block_idx = positions // block_size  # [num_tokens]
-
-    # 2. Look up the block_id in block_tables:
-    #    block_tables[batch_id_per_token, block_idx]
-    block_ids = block_tables[batch_id_per_token, block_idx]  # [num_tokens]
-
-    # 3. Compute the offset within the block
-    block_offset = positions % block_size  # [num_tokens]
-
-    # 4. Compute slot_mapping
-    slot_mapping = block_ids * block_size + block_offset
-
-    return slot_mapping.cast(paddle.int64)
-
-
 @dataclass
 class DSAAttentionMetadata(AttentionMetadata):
     """
@@ -346,18 +319,11 @@ def forward_mixed(
 
         scale = paddle.abs(compressed_kv).max() / 200.0
 
-        slot_mapping = compute_slot_mapping(
-            forward_meta.block_tables,
-            forward_meta.position_ids,
-            forward_meta.batch_id_per_token,
-            self.block_size,
-        )
-
         dsk_attn_write_cache(
             compressed_kv,
             k_pe,
             latent_cache,
-            slot_mapping,
+            forward_meta.slot_mapping,
             scale.cast(paddle.float32),
             "fp8_ds_mla",
         )

fastdeploy/model_executor/models/deepseek_v3.py

Lines changed: 6 additions & 71 deletions

@@ -46,6 +46,9 @@
 from fastdeploy.model_executor.layers.lm_head import ParallelLMHead
 from fastdeploy.model_executor.layers.moe.moe import FusedMoE
 from fastdeploy.model_executor.layers.normalization import LayerNorm, RMSNorm
+from fastdeploy.model_executor.layers.quantization.fp8_utils import (
+    per_token_group_quant_fp8,
+)
 from fastdeploy.model_executor.layers.rotary_embedding import (
     DeepseekScalingRotaryEmbedding,
 )
@@ -59,16 +62,6 @@
 )
 from fastdeploy.platforms import current_platform
 
-if current_platform.is_cuda() or current_platform.is_maca():
-    from fastdeploy.model_executor.ops.gpu import (
-        get_position_ids_and_mask_encoder_batch,
-    )
-
-from fastdeploy.model_executor.layers.quantization.fp8_utils import (
-    per_token_group_quant_fp8,
-)
-from fastdeploy.platforms import current_platform
-
 if current_platform.is_cuda():
     from fastdeploy.model_executor.ops.gpu import (
         cp_gather_indexer_k_quant_cache,
@@ -471,33 +464,6 @@ def forward(
         return output
 
 
-def compute_slot_mapping(
-    block_tables: paddle.Tensor,  # [num_reqs, max_blocks_per_req]
-    positions: paddle.Tensor,  # [num_tokens] position of each token
-    batch_id_per_token: paddle.Tensor,  # [num_tokens] which request each token belongs to
-    block_size: int,
-) -> paddle.Tensor:
-    """
-    Compute slot_mapping.
-
-    Formula: slot = block_id * block_size + offset_in_block
-    """
-    # 1. Compute each token's block index
-    block_idx = positions // block_size  # [num_tokens]
-
-    # 2. Look up the block_id in block_tables:
-    #    block_tables[batch_id_per_token, block_idx]
-    block_ids = block_tables[batch_id_per_token, block_idx]  # [num_tokens]
-
-    # 3. Compute the offset within the block
-    block_offset = positions % block_size  # [num_tokens]
-
-    # 4. Compute slot_mapping
-    slot_mapping = block_ids * block_size + block_offset
-
-    return slot_mapping.cast(paddle.int64)
-
-
 import triton
 import triton.language as tl
 
@@ -686,17 +652,12 @@ def forward(
         weights = weights.unsqueeze(-1) * q_scale * self.softmax_scale * self.index_n_heads**-0.5
         weights = weights.squeeze(-1)
 
-        slot_mapping = compute_slot_mapping(
-            forward_meta.block_tables,
-            forward_meta.position_ids,
-            forward_meta.batch_id_per_token,
-            64,
-        )
-
         indexer_top_k = paddle.full([q_fp8.shape[0], self.index_topk], -1, dtype="int32")
 
         # indexer write_cache
-        indexer_k_quant_and_cache(k, self.indexer_cache, slot_mapping, self.quant_block_size, self.scale_fmt)
+        indexer_k_quant_and_cache(
+            k, self.indexer_cache, forward_meta.slot_mapping, self.quant_block_size, self.scale_fmt
+        )
 
         from fastdeploy.model_executor.layers.quantization.fp8_utils import deep_gemm
 
@@ -1172,12 +1133,6 @@ def __init__(self, fd_config: FDConfig):
             num_embeddings=fd_config.model_config.vocab_size,
             prefix="lm_head",
         )
-        self.position_ids_buffer = paddle.empty(
-            [fd_config.scheduler_config.max_num_batched_tokens], dtype=paddle.int32
-        )
-        self.mask_encoder_batch_buffer = paddle.empty(
-            [fd_config.scheduler_config.max_num_batched_tokens, 1], dtype=paddle.int32
-        )
 
     @classmethod
     def name(cls):
@@ -1274,25 +1229,6 @@ def compute_logits(self, hidden_states: paddle.Tensor, forward_meta: ForwardMeta
         logits[:, self.ori_vocab_size :] = -float("inf")
         return logits
 
-    def pre_process(self, forward_meta):
-        """ """
-        seq_lens_encoder = forward_meta.seq_lens_encoder
-        seq_lens_decoder = forward_meta.seq_lens_decoder
-        seq_lens_this_time = forward_meta.seq_lens_this_time
-
-        current_total_tokens = forward_meta.ids_remove_padding.shape[0]
-        position_ids = self.position_ids_buffer[:current_total_tokens]
-        mask_encoder_batch = self.mask_encoder_batch_buffer[:current_total_tokens]
-
-        get_position_ids_and_mask_encoder_batch(
-            seq_lens_encoder,
-            seq_lens_decoder,
-            seq_lens_this_time,
-            position_ids,
-            mask_encoder_batch,
-        )
-        return position_ids, mask_encoder_batch
-
     def empty_input_forward(self, forward_meta):
         """
         empty_input_forward
@@ -1313,7 +1249,6 @@ def forward(
         forward_meta: ForwardMeta,
     ):
         ids_remove_padding = inputs["ids_remove_padding"]
-        forward_meta.position_ids, forward_meta.mask_encoder_batch = self.pre_process(forward_meta)
         hidden_states = self.model(
             ids_remove_padding=ids_remove_padding,
             forward_meta=forward_meta,

fastdeploy/worker/gpu_model_runner.py

Lines changed: 38 additions & 0 deletions

@@ -45,6 +45,12 @@
 from fastdeploy.model_executor.layers.attention.base_attention_backend import (
     AttentionBackend,
 )
+from fastdeploy.model_executor.layers.attention.dsa_attention_backend import (
+    DSAAttentionBackend,
+)
+from fastdeploy.model_executor.layers.attention.mla_attention_backend import (
+    MLAAttentionBackend,
+)
 from fastdeploy.model_executor.layers.moe.routing_indices_cache import (
     RoutingReplayManager,
 )
@@ -79,6 +85,7 @@
     speculate_schedule_cache,
     set_data_ipc,
     unset_data_ipc,
+    get_position_ids_and_mask_encoder_batch,
 )
 
 import zmq
@@ -1267,6 +1274,33 @@ def _prepare_inputs(self, cached_token_num=-1, cached_real_bsz=-1, is_dummy_or_p
         )
         return token_num, token_num_event
 
+    def _compute_position_ids_and_slot_mapping(self) -> None:
+        """Compute position_ids and slot_mapping for KV cache addressing.
+
+        This is a general computation based on sequence-length info and block
+        tables, applicable to all models that need per-token KV cache physical
+        slot addresses. Results are stored in self.forward_meta.
+        """
+        # NOTE(zhushengguang): Only supports MLAAttentionBackend and DSAAttentionBackend currently.
+        if not isinstance(self.attn_backends[0], (MLAAttentionBackend, DSAAttentionBackend)):
+            return
+        current_total_tokens = self.forward_meta.ids_remove_padding.shape[0]
+        position_ids = self.share_inputs["position_ids_buffer"][:current_total_tokens]
+        get_position_ids_and_mask_encoder_batch(
+            self.forward_meta.seq_lens_encoder,
+            self.forward_meta.seq_lens_decoder,
+            self.forward_meta.seq_lens_this_time,
+            position_ids,
+        )
+        block_size = self.cache_config.block_size
+        block_idx = position_ids // block_size  # [num_tokens]
+        assert self.forward_meta.batch_id_per_token.shape == block_idx.shape
+        block_ids = self.forward_meta.block_tables[self.forward_meta.batch_id_per_token, block_idx]  # [num_tokens]
+        block_offset = position_ids % block_size  # [num_tokens]
+        slot_mapping = self.share_inputs["slot_mapping_buffer"][:current_total_tokens]
+        paddle.assign((block_ids * block_size + block_offset).cast(paddle.int64), slot_mapping)
+        self.forward_meta.position_ids = position_ids
+        self.forward_meta.slot_mapping = slot_mapping
+
     def _process_reorder(self) -> None:
         if self.attn_backends and getattr(self.attn_backends[0], "enable_ids_reorder", False):
             self.share_inputs.enable_pd_reorder = True
@@ -1860,6 +1894,8 @@ def _dummy_run(
         # 2. Padding inputs for cuda graph
         self.forward_meta.step_use_cudagraph = in_capturing and self.forward_meta.step_use_cudagraph
         self.padding_cudagraph_inputs()
+        # Compute position_ids and slot_mapping
+        self._compute_position_ids_and_slot_mapping()
 
         model_inputs = {}
         model_inputs["ids_remove_padding"] = self.share_inputs["ids_remove_padding"]
@@ -2197,6 +2233,8 @@ def _preprocess(
 
         # Padding inputs for cuda graph
         self.padding_cudagraph_inputs()
+        # Compute position_ids and slot_mapping
+        self._compute_position_ids_and_slot_mapping()
 
         model_inputs = {}
         model_inputs["ids_remove_padding"] = self.share_inputs["ids_remove_padding"]

fastdeploy/worker/input_batch.py

Lines changed: 5 additions & 0 deletions

@@ -188,6 +188,11 @@ def init_share_inputs(self):
         self.cu_seqlens_q = paddle.full([max_num_seqs + 1], 0, dtype="int32")
         self.cu_seqlens_k = paddle.full([max_num_seqs + 1], 0, dtype="int32")
 
+        # Initialize addressing buffers
+        _max_batched_tokens = self.scheduler_config.max_num_batched_tokens
+        self.position_ids_buffer = paddle.zeros([_max_batched_tokens], dtype=paddle.int32)
+        self.slot_mapping_buffer = paddle.zeros([_max_batched_tokens], dtype=paddle.int64)
+
         # Declare AttentionBackend buffers
         self.decoder_batch_ids = None
         self.decoder_tile_ids_per_batch = None
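
Pre-allocating both buffers at max_num_batched_tokens means each step can take a view of the needed prefix instead of allocating fresh tensors; _compute_position_ids_and_slot_mapping above slices buffer[:current_total_tokens] and writes into it with paddle.assign. A toy illustration of that slice-and-assign pattern (the buffer length here is an arbitrary stand-in for max_num_batched_tokens):

import paddle

slot_mapping_buffer = paddle.zeros([2048], dtype=paddle.int64)  # max_num_batched_tokens
current_total_tokens = 5
slot_mapping = slot_mapping_buffer[:current_total_tokens]  # a view into the buffer
paddle.assign(paddle.arange(5, dtype=paddle.int64), slot_mapping)  # write in place
print(slot_mapping_buffer[:8].numpy())  # [0 1 2 3 4 0 0 0]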

tests/distributed/chunked_moe.py

Lines changed: 1 addition & 0 deletions

@@ -85,6 +85,7 @@ class SchedulerConfig:
     name = "default"
     splitwise_role = "mixed"
     max_num_seqs = 2
+    max_num_batched_tokens = 2048
 
 parallel_config = ParallelConfig()
 scheduler_config = SchedulerConfig()
