Skip to content

Commit 48d2bbe

Browse files
authored
fix dsa (#7252)
1 parent b262419 commit 48d2bbe

3 files changed

Lines changed: 55 additions & 49 deletions

File tree

custom_ops/gpu_ops/sparse_indexer/indexer_topk.cuh

Lines changed: 12 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -2061,22 +2061,28 @@ __global__ void __launch_bounds__(FILTERED_TOPK_BLOCK_THREADS)
20612061
int batch_id, length;
20622062
const IdType* block_table_pre_batch;
20632063

2064+
IdType* dst;
2065+
20642066
if (seq_len_decoder != nullptr) { // decode
2065-
batch_id = batch_id_per_token[bid / q_num_heads];
2067+
// batch_id = batch_id_per_token[bid / q_num_heads];
2068+
batch_id = bid / q_num_heads;
20662069
if (batch_id == -1) return;
20672070
length = (seq_len_decoder[batch_id]); // for pack q k
20682071
if (length == 0) return;
20692072
if (block_tables != nullptr) {
20702073
block_table_pre_batch = block_tables + batch_id * max_block_num;
20712074
}
2075+
dst = output + aux_input[batch_id] * top_k;
2076+
20722077
} else { // prefill
20732078
// length = (lengths != nullptr) ? lengths[bid] : static_cast<int>(max_len);
20742079
length = (lengths != nullptr) ? lengths[bid / q_num_heads]
20752080
: static_cast<int>(max_len);
2081+
dst = output + bid * top_k;
20762082
}
20772083

20782084
const DType* score = input + bid * max_len;
2079-
IdType* dst = output + bid * top_k;
2085+
// IdType* dst = output + bid * top_k;
20802086

20812087
// Mode-specific setup
20822088
[[maybe_unused]] const IdType* src_page_entry = nullptr;
@@ -2110,8 +2116,8 @@ __global__ void __launch_bounds__(FILTERED_TOPK_BLOCK_THREADS)
21102116
? static_cast<IdType>(block_ids * 64 + block_offset)
21112117
: static_cast<IdType>(-1);
21122118
} else {
2113-
dst[i] =
2114-
(i < length) ? static_cast<IdType>(i) : static_cast<IdType>(-1);
2119+
dst[i] = (i < length) ? static_cast<IdType>(i) + offset_val
2120+
: static_cast<IdType>(-1);
21152121
}
21162122
} else { // Plain
21172123
if (i < length) {
@@ -2337,10 +2343,9 @@ __global__ void __launch_bounds__(FILTERED_TOPK_BLOCK_THREADS)
23372343
block_idx = idx / 64;
23382344
block_ids = block_table_pre_batch[block_idx];
23392345
block_offset = idx % 64;
2340-
dst[base] =
2341-
static_cast<IdType>(block_ids * 64 + block_offset); // + offset_val
2346+
dst[base] = static_cast<IdType>(block_ids * 64 + block_offset);
23422347
} else {
2343-
dst[base] = static_cast<IdType>(idx); //+ offset_val;
2348+
dst[base] = static_cast<IdType>(idx) + offset_val;
23442349
}
23452350

23462351
} else { // Plain

fastdeploy/model_executor/models/deepseek_v3.py

Lines changed: 10 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -674,14 +674,13 @@ def forward(
674674
self.indexer_cache, k_fp8_cache, k_scale_cache, forward_meta.block_tables, forward_meta.cu_seqlens_k
675675
)
676676

677-
k_scale_cache = k_scale_cache.flatten()[: k.shape[0]]
678-
k_cache = k_fp8_cache.view(paddle.float8_e4m3fn), k_scale_cache
677+
k_scale_cache_real = k_scale_cache.flatten()[: k.shape[0]].contiguous()
678+
k_cache = k_fp8_cache.view(paddle.float8_e4m3fn), k_scale_cache_real
679679

680680
# TODO(changwenbin): Constructed using maskoffset
681681
# ks,ke = forward_meta.attn_mask_offsets[::2].contiguous(),forward_meta.attn_mask_offsets[1::2].contiguous()
682682
num_tokens = q_fp8.shape[0]
683683
ks = paddle.zeros(num_tokens, dtype=paddle.int32)
684-
ks_topk = paddle.zeros(num_tokens, dtype=paddle.int32)
685684
ke = paddle.zeros(num_tokens, dtype=paddle.int32)
686685

687686
bsz = forward_meta.seq_lens_this_time.shape[0]
@@ -696,20 +695,13 @@ def forward(
696695

697696
logits = deep_gemm.fp8_mqa_logits(
698697
q_fp8, k_cache, weights, ks, ke, max_seqlen_k=max_seqlen_k, clean_logits=False
699-
)
700-
701-
# To save GPU global memory usage
702-
assert logits.size() == (num_tokens, max_seqlen_k)
703-
tmp = paddle.full((num_tokens, num_tokens), float("-inf"))
704-
for i in range(num_tokens):
705-
tmp[i, ks[i] : ke[i]] = logits[i, : ke[i] - ks[i]]
706-
logits = tmp
698+
).contiguous()
707699

708700
radix_topk_ragged_transform(
709-
logits.contiguous(),
701+
logits,
710702
indexer_top_k,
711-
ks_topk, # self.offsets,
712-
ke - ks + 1, # mask.contiguous(),#self.lengths,
703+
ks, # self.offsets,# 初始K方向偏移,
704+
ke - ks, # self.lengths,# 表明当前q 关注的k有多长;
713705
None, # forward_meta.seq_lens_decoder,
714706
None, # forward_meta.batch_id_per_token,
715707
None,
@@ -740,20 +732,20 @@ def forward(
740732
schedule_metadata,
741733
self.max_model_len,
742734
clean_logits=True,
743-
)
735+
).contiguous()
744736

745737
radix_topk_ragged_transform(
746-
logits.contiguous(),
738+
logits,
747739
indexer_top_k,
748-
self.offsets, # unused
740+
forward_meta.cu_seqlens_q,
749741
self.lengths, # unused
750742
cache_seqlens,
751743
forward_meta.batch_id_per_token,
752744
forward_meta.block_tables,
753745
None, # self.buffer
754746
forward_meta.block_tables.shape[1],
755747
self.index_topk,
756-
1, # q_head
748+
1, # kv_head
757749
)
758750

759751
return indexer_top_k

tests/operators/test_radix_topk_accuracy.py

Lines changed: 33 additions & 24 deletions
Original file line number | Diff line number | Diff line change
@@ -36,38 +36,40 @@ def setUp(self):
3636
def get_reference_topk(self, input_pd, lengths_pd, offsets_pd, top_k, q_num_heads):
3737
"""
3838
使用 paddle.topk 生成参考结果
39-
注意:算子输出的索引是 0-based 相对索引(不包含 offset
39+
注意:算子输出的索引是相对于 offsets 的偏移量(0-based 相对索引)
4040
4141
Args:
4242
input_pd: (num_rows, max_len)
4343
lengths_pd: (batch_size,) - 每个batch的长度
44-
offsets_pd: (num_rows,) - 每一行的偏移基点(未使用,仅保留参数兼容性)
44+
offsets_pd: (num_rows,) - 每一行的偏移基点
4545
top_k: k值
4646
q_num_heads: query head数量
4747
4848
Returns:
49-
ref_indices: (num_rows, top_k) - 参考索引(0-based 相对索引),长度不足的部分用-1填充
49+
ref_indices: (num_rows, top_k) - 参考索引(相对于 offset 的偏移),长度不足的部分用-1填充
5050
"""
5151
num_rows = input_pd.shape[0]
5252
ref_indices = paddle.full([num_rows, top_k], -1, dtype="int32")
53+
offsets = offsets_pd.numpy()
5354

5455
for row_idx in range(num_rows):
5556
batch_idx = row_idx // q_num_heads
5657
length = lengths_pd[batch_idx].item()
58+
offset = offsets[row_idx]
5759

5860
if length == 0:
5961
continue
6062

6163
row_data = input_pd[row_idx, :length]
6264

6365
if length <= top_k:
64-
# 长度不足top_k,按顺序返回所有索引(0-based
65-
ref_indices[row_idx, :length] = paddle.arange(0, length, dtype="int32")
66+
# 长度不足top_k,按顺序返回所有索引(相对于 offset
67+
ref_indices[row_idx, :length] = paddle.arange(offset, offset + length, dtype="int32")
6668
else:
6769
# 长度足够,使用 paddle.topk 获取最大的top_k个值的索引
6870
topk_vals, topk_inds = paddle.topk(row_data, top_k)
69-
# 直接使用 topk 返回的索引(0-based)
70-
ref_indices[row_idx, :top_k] = topk_inds
71+
# 加上 offset 作为基点
72+
ref_indices[row_idx, :top_k] = topk_inds + offset
7173

7274
return ref_indices
7375

@@ -171,41 +173,48 @@ def test_decode_mode(self):
171173
paddle.seed(2025)
172174

173175
batch_size = 2
174-
q_num_heads = 4
175-
num_rows = batch_size * q_num_heads
176+
kv_head = 1 # decode 模式下,每个 batch 只有一个新 token
177+
num_rows = batch_size * kv_head # = batch_size
176178
max_len = 1024
177179
top_k = 8
178180

179181
# 使用 paddle 构造数据
180182
input_pd = paddle.randn([num_rows, max_len], dtype="float32")
181-
offsets_pd = paddle.arange(num_rows, dtype="int32")
182-
lengths_pd = paddle.full([num_rows], 0, dtype="int32")
183-
seq_len_decoder_pd = paddle.randint(16, 128, [batch_size], dtype="int32")
184183

185-
# 生成 batch_id_per_token
186-
batch_id_per_token_pd = paddle.arange(num_rows, dtype="int32") // q_num_heads
184+
# 生成 cu_seqlens_q: 每个 batch 在打平的 query 中的偏移量
185+
# 在 decode 模式下,每个 batch 只有一个新 token,所以 cu_seqlens_q = [0, 1, 2, ..., batch_size]
186+
cu_seqlens_q_pd = paddle.concat(
187+
[
188+
paddle.zeros([1], dtype="int32"),
189+
paddle.cumsum(paddle.ones([batch_size], dtype="int32")).astype("int32"),
190+
],
191+
axis=0,
192+
)
187193

188-
# 调用算子
194+
lengths_pd = paddle.full([num_rows], 0, dtype="int32") # unused
195+
seq_len_decoder_pd = paddle.randint(16, 128, [batch_size], dtype="int32")
196+
197+
# 调用算子(不使用 block_tables,让它按照 prefill 模式类似的逻辑工作)
189198
output_indices = paddle.full([num_rows, top_k], -1, dtype="int32")
190199
radix_topk_ragged_transform(
191200
input_pd,
192201
output_indices,
193-
offsets_pd,
194-
lengths_pd,
202+
cu_seqlens_q_pd,
203+
lengths_pd, # unused
195204
seq_len_decoder_pd,
196-
batch_id_per_token_pd,
197-
None,
198-
None,
199-
0,
205+
None, # batch_id_per_token
206+
None, # block_tables
207+
None, # buffer
208+
0, # max_block_per_seq
200209
top_k,
201-
q_num_heads,
210+
kv_head,
202211
)
203212

204213
# Decode 模式下,长度 = seq_len_decoder + 1
205214
decode_lengths = seq_len_decoder_pd + 1
206215

207-
# 获取参考结果
208-
ref_indices = self.get_reference_topk(input_pd, decode_lengths, offsets_pd, top_k, q_num_heads)
216+
# 获取参考结果(注意:num_rows = batch_size * kv_head)
217+
ref_indices = self.get_reference_topk(input_pd, decode_lengths, cu_seqlens_q_pd, top_k, kv_head)
209218

210219
# 对比结果
211220
result = self.compare_indices(output_indices, ref_indices)

0 commit comments

Comments
 (0)