@@ -36,38 +36,40 @@ def setUp(self):
def get_reference_topk(self, input_pd, lengths_pd, offsets_pd, top_k, q_num_heads):
    """
    Build reference top-k indices with paddle.topk.

    Every index written to the result is the position of the selected
    element inside the row's valid prefix, shifted by that row's entry
    in ``offsets_pd``.

    Args:
        input_pd: (num_rows, max_len) values to select from.
        lengths_pd: (batch_size,) valid length of each batch; row ``r``
            reads entry ``r // q_num_heads``.
        offsets_pd: per-row base offset, read at index ``r`` and added to
            every index produced for that row.
        top_k: number of indices kept per row.
        q_num_heads: number of consecutive rows sharing one length entry.

    Returns:
        int32 tensor of shape (num_rows, top_k). Slots that are not
        filled (rows shorter than top_k, or empty rows) hold -1.
    """
    row_count = input_pd.shape[0]
    expected = paddle.full([row_count, top_k], -1, dtype="int32")
    base_per_row = offsets_pd.numpy()

    for r in range(row_count):
        valid_len = lengths_pd[r // q_num_heads].item()
        base = base_per_row[r]

        if valid_len == 0:
            # Nothing to select: the row keeps its -1 padding.
            continue

        if valid_len > top_k:
            # Enough candidates: take the top_k largest values of the
            # valid prefix and shift their positions by the row's base.
            _, picked = paddle.topk(input_pd[r, :valid_len], top_k)
            expected[r, :top_k] = picked + base
            continue

        # At most top_k candidates: every valid position is returned,
        # in order, starting from the row's base.
        expected[r, :valid_len] = paddle.arange(base, base + valid_len, dtype="int32")

    return expected
7375
@@ -171,41 +173,48 @@ def test_decode_mode(self):
171173 paddle .seed (2025 )
172174
173175 batch_size = 2
174- q_num_heads = 4
175- num_rows = batch_size * q_num_heads
176+ kv_head = 1 # decode 模式下,每个 batch 只有一个新 token
177+ num_rows = batch_size * kv_head # = batch_size
176178 max_len = 1024
177179 top_k = 8
178180
179181 # 使用 paddle 构造数据
180182 input_pd = paddle .randn ([num_rows , max_len ], dtype = "float32" )
181- offsets_pd = paddle .arange (num_rows , dtype = "int32" )
182- lengths_pd = paddle .full ([num_rows ], 0 , dtype = "int32" )
183- seq_len_decoder_pd = paddle .randint (16 , 128 , [batch_size ], dtype = "int32" )
184183
185- # 生成 batch_id_per_token
186- batch_id_per_token_pd = paddle .arange (num_rows , dtype = "int32" ) // q_num_heads
184+ # 生成 cu_seqlens_q: 每个 batch 在打平的 query 中的偏移量
185+ # 在 decode 模式下,每个 batch 只有一个新 token,所以 cu_seqlens_q = [0, 1, 2, ..., batch_size]
186+ cu_seqlens_q_pd = paddle .concat (
187+ [
188+ paddle .zeros ([1 ], dtype = "int32" ),
189+ paddle .cumsum (paddle .ones ([batch_size ], dtype = "int32" )).astype ("int32" ),
190+ ],
191+ axis = 0 ,
192+ )
187193
188- # 调用算子
194+ lengths_pd = paddle .full ([num_rows ], 0 , dtype = "int32" ) # unused
195+ seq_len_decoder_pd = paddle .randint (16 , 128 , [batch_size ], dtype = "int32" )
196+
197+ # 调用算子(不使用 block_tables,让它按照 prefill 模式类似的逻辑工作)
189198 output_indices = paddle .full ([num_rows , top_k ], - 1 , dtype = "int32" )
190199 radix_topk_ragged_transform (
191200 input_pd ,
192201 output_indices ,
193- offsets_pd ,
194- lengths_pd ,
202+ cu_seqlens_q_pd ,
203+ lengths_pd , # unused
195204 seq_len_decoder_pd ,
196- batch_id_per_token_pd ,
197- None ,
198- None ,
199- 0 ,
205+ None , # batch_id_per_token
206+ None , # block_tables
207+ None , # buffer
208+ 0 , # max_block_per_seq
200209 top_k ,
201- q_num_heads ,
210+ kv_head ,
202211 )
203212
204213 # Decode 模式下,长度 = seq_len_decoder + 1
205214 decode_lengths = seq_len_decoder_pd + 1
206215
207- # 获取参考结果
208- ref_indices = self .get_reference_topk (input_pd , decode_lengths , offsets_pd , top_k , q_num_heads )
216+ # 获取参考结果(注意:num_rows = batch_size * kv_head)
217+ ref_indices = self .get_reference_topk (input_pd , decode_lengths , cu_seqlens_q_pd , top_k , kv_head )
209218
210219 # 对比结果
211220 result = self .compare_indices (output_indices , ref_indices )
0 commit comments