@@ -159,6 +159,7 @@ std::vector<paddle::Tensor> BlockAttnKernel(
159159 if (use_neox_rotary_style) {
160160 pos_emb_type = " NEOX" ;
161161 } else if (rope_head_dim == head_dim / 2 ) {
162+ // VL models take this branch (rope_head_dim == head_dim / 2)
162163 pos_emb_type = " HALF_HEAD_DIM" ;
163164 } else {
164165 pos_emb_type = " NORMAL" ;
@@ -984,7 +985,7 @@ std::vector<paddle::Tensor> BlockAttnKernel(
984985 return {block_attn_out};
985986}
986987
987- std::vector<paddle::Tensor> BlockAttn (
988+ std::vector<paddle::Tensor> BlockAttnFused (
988989 const paddle::Tensor& qkv,
989990 const paddle::Tensor& key_cache,
990991 const paddle::Tensor& value_cache,
@@ -1008,6 +1009,8 @@ std::vector<paddle::Tensor> BlockAttn(
10081009 const paddle::Tensor& decoder_context_len_cache,
10091010 const paddle::Tensor& decoder_batch_map,
10101011 const paddle::Tensor& prefix_len,
1012+ const paddle::Tensor& slot_mapping_enc,
1013+ const paddle::Tensor& slot_mapping_dec,
10111014 const paddle::optional<paddle::Tensor>& k_scales,
10121015 const paddle::optional<paddle::Tensor>& v_scales,
10131016 const paddle::optional<paddle::Tensor>& k_scales_inv,
@@ -1067,7 +1070,7 @@ std::vector<paddle::Tensor> BlockAttn(
10671070 } else if (cache_dtype == paddle::DataType::INT8) {
10681071 APPLY_KERNEL (paddle::bfloat16, int8_t , paddle::bfloat16);
10691072 } else {
1070- PD_THROW (" block_attn not support cache_dtype==%d" ,
1073+ PD_THROW (" block_attn_fused not support cache_dtype==%d" ,
10711074 static_cast <int >(cache_dtype));
10721075 return {};
10731076 }
@@ -1097,7 +1100,7 @@ std::vector<paddle::DataType> BlockAttnInferDtype(
10971100 return {qkv_dtype};
10981101}
10991102
1100- PD_BUILD_STATIC_OP (block_attn )
1103+ PD_BUILD_STATIC_OP (block_attn_fused )
11011104 .Inputs({" qkv" ,
11021105 " key_cache" ,
11031106 " value_cache" ,
@@ -1121,6 +1124,8 @@ PD_BUILD_STATIC_OP(block_attn)
11211124 " decoder_context_len_cache" ,
11221125 " decoder_batch_map" ,
11231126 " prefix_len" ,
1127+ " slot_mapping_enc" ,
1128+ " slot_mapping_dec" ,
11241129 paddle::Optional (" k_scales" ),
11251130 paddle::Optional (" v_scales" ),
11261131 paddle::Optional (" k_scales_inv" ),
@@ -1135,6 +1140,6 @@ PD_BUILD_STATIC_OP(block_attn)
11351140 paddle::Optional (" cachekv_signal_thread_cpu" )})
11361141 .Attrs({" use_neox_rotary_style:bool" , " rope_3d:bool" })
11371142 .Outputs({" block_attn_out" })
1138- .SetKernelFn(PD_KERNEL(BlockAttn ))
1143+ .SetKernelFn(PD_KERNEL(BlockAttnFused ))
11391144 .SetInferShapeFn(PD_INFER_SHAPE(BlockAttnInferShape))
11401145 .SetInferDtypeFn(PD_INFER_DTYPE(BlockAttnInferDtype));
0 commit comments