@@ -1045,17 +1045,129 @@ def forward_cuda(
         sampler_output.cu_batch_token_offset = cu_batch_token_offset.cpu()
         return sampler_output

-    def forward_xpu(
+    def _normal_sample_xpu(
+        self,
+        logits: paddle.Tensor,
+        probs: paddle.Tensor,
+        sampling_metadata: SamplingMetadata,
+        share_inputs: List[paddle.Tensor],
+    ) -> SamplerOutput:
+        """Normal sampling for NAIVE mode on XPU."""
+        top_p, top_k, topp_seed = padding_sampling_params(
+            sampling_metadata.top_p,
+            sampling_metadata.top_k,
+            sampling_metadata.seed,
+            paddle.reshape(share_inputs["seq_lens_this_time"], shape=[-1]),
+            paddle.reshape(share_inputs["seq_lens_encoder"], shape=[-1]),
+        )
+        _, next_tokens = top_k_top_p_sampling(
+            probs,
+            top_p=top_p,
+            top_k=top_k,
+            top_k_list=sampling_metadata.top_k_list,
+            topp_seed=topp_seed,
+        )
+        real_bsz = share_inputs["seq_lens_this_time"].shape[0]
+        running_mask = (paddle.reshape(share_inputs["seq_lens_this_time"], shape=[-1]) > 0).cast("int32")
+        share_inputs["accept_tokens"][:real_bsz, 0] = next_tokens.squeeze(-1)
+        share_inputs["accept_num"][:real_bsz] = running_mask
+        return SamplerOutput(
+            sampled_token_ids=share_inputs["accept_tokens"],
+            logprobs_tensors=None,
+            token_num_per_batch=share_inputs["accept_num"],
+            logits=logits,
+        )
+
+    def _verify_and_sample_xpu(
         self,
         logits: paddle.Tensor,
+        probs: paddle.Tensor,
         sampling_metadata: SamplingMetadata,
         max_model_len: int,
         share_inputs: List[paddle.Tensor],
         accept_all_drafts: bool = False,
         reject_all_drafts: bool = False,
-    ) -> paddle.Tensor:
-        from fastdeploy.model_executor.ops.xpu import speculate_verify, top_p_candidates
+    ) -> SamplerOutput:
+        """Verify draft tokens (MTP/Ngram mode) on XPU using verify_draft_tokens."""
+        from fastdeploy.model_executor.ops.xpu import (
+            top_p_candidates,
+            verify_draft_tokens,
+        )
+
+        target_tokens = None
+        candidate_ids, candidate_scores, candidate_lens = None, None, None

+        if self.verify_strategy == VerifyStrategy.TARGET_MATCH:
+            top_p, top_k, topp_seed = padding_sampling_params(
+                sampling_metadata.top_p,
+                sampling_metadata.top_k,
+                sampling_metadata.seed,
+                paddle.reshape(share_inputs["seq_lens_this_time"], shape=[-1]),
+                paddle.reshape(share_inputs["seq_lens_encoder"], shape=[-1]),
+            )
+            _, target_tokens = top_k_top_p_sampling(
+                probs,
+                top_p=top_p,
+                top_k=top_k,
+                top_k_list=sampling_metadata.top_k_list,
+                topp_seed=topp_seed,
+            )
+        elif self.verify_strategy == VerifyStrategy.GREEDY:
+            target_tokens = paddle.argmax(probs, axis=-1)
+        elif self.verify_strategy == VerifyStrategy.TOPP:
+            candidate_scores, candidate_ids, candidate_lens = top_p_candidates(
+                probs,
+                sampling_metadata.top_p,
+                share_inputs["batch_id_per_token_output"],
+                self.speculative_max_candidate_len,
+                max_model_len,
+            )
+        else:
+            raise ValueError(f"Unknown verify strategy: {self.verify_strategy}")
+
+        final_accept_all = self.config_accept_all or accept_all_drafts
+        final_reject_all = self.config_reject_all or reject_all_drafts or self.speculative_benchmark_mode
+
+        verify_draft_tokens(
+            share_inputs["accept_tokens"],
+            share_inputs["accept_num"],
+            share_inputs["draft_tokens"],
+            target_tokens,
+            candidate_ids,
+            candidate_scores,
+            candidate_lens,
+            sampling_metadata.top_p,
+            share_inputs["stop_flags"],
+            share_inputs["seq_lens_encoder"],
+            share_inputs["seq_lens_this_time"],
+            sampling_metadata.eos_token_ids,
+            share_inputs["is_block_step"],
+            share_inputs["cu_seqlens_q_output"],
+            share_inputs["reasoning_status"],
+            share_inputs["max_dec_len"],
+            share_inputs["step_idx"],
+            max_model_len,
+            self.speculative_verify_window,
+            self.verify_strategy.value,
+            final_reject_all,
+            final_accept_all,
+        )
+        return SamplerOutput(
+            sampled_token_ids=share_inputs["accept_tokens"],
+            logprobs_tensors=None,
+            token_num_per_batch=share_inputs["accept_num"],
+            logits=logits,
+        )
+
+    def forward_xpu(
+        self,
+        logits: paddle.Tensor,
+        sampling_metadata: SamplingMetadata,
+        max_model_len: int,
+        share_inputs: List[paddle.Tensor],
+        accept_all_drafts: bool = False,
+        reject_all_drafts: bool = False,
+    ) -> SamplerOutput:
         logits = apply_speculative_penalty_multi_scores(
             sampling_metadata.token_ids_all,
             sampling_metadata.prompt_lens,
@@ -1078,61 +1190,19 @@ def forward_xpu(

         probs = F.softmax(logits)

-        top_p, top_k, topp_seed = padding_sampling_params(
-            sampling_metadata.top_p,
-            sampling_metadata.top_k,
-            sampling_metadata.seed,
-            paddle.reshape(share_inputs["seq_lens_this_time"], shape=[-1]),
-            paddle.reshape(share_inputs["seq_lens_encoder"], shape=[-1]),
-        )
-        _, sampled_token_ids = top_k_top_p_sampling(
-            probs, top_p=top_p, top_k=top_k, top_k_list=sampling_metadata.top_k_list, topp_seed=topp_seed
-        )
-
-        verify_scores, verify_tokens, actual_candidate_len = top_p_candidates(
-            probs,
-            sampling_metadata.top_p,
-            share_inputs["batch_id_per_token_output"],
-            self.speculative_max_candidate_len,
-            max_model_len,
-        )
-
-        speculate_verify(
-            sampled_token_ids,
-            share_inputs["accept_tokens"],
-            share_inputs["accept_num"],
-            share_inputs["step_idx"],
-            share_inputs["stop_flags"],
-            share_inputs["seq_lens_encoder"],
-            share_inputs["seq_lens_decoder"],
-            share_inputs[
-                "draft_tokens"
-            ],  # Both input and output; the last accepted token must be written to position 0.
-            share_inputs["seq_lens_this_time"],
-            verify_tokens,
-            verify_scores,
-            share_inputs["max_dec_len"],
-            sampling_metadata.eos_token_ids,
-            share_inputs["is_block_step"],
-            share_inputs["cu_seqlens_q_output"],
-            actual_candidate_len,
-            share_inputs["actual_draft_token_num"],
-            sampling_metadata.top_p,
-            max_model_len,
-            self.speculative_verify_window,
-            True,  # enable_topp
-            (self.speculative_benchmark_mode or reject_all_drafts),
-            accept_all_drafts,
-        )
-        # TODO(chenhuan09): support return logprobs
-        token_ids = share_inputs["accept_tokens"]
-        sampler_output = SamplerOutput(
-            sampled_token_ids=token_ids,
-            logprobs_tensors=None,
-            token_num_per_batch=share_inputs["accept_num"],
-            cu_batch_token_offset=None,
-        )
-        return sampler_output
+        is_naive = self.spec_method is None or self.spec_method == SpecMethod.NAIVE
+        if is_naive:
+            return self._normal_sample_xpu(logits, probs, sampling_metadata, share_inputs)
+        else:
+            return self._verify_and_sample_xpu(
+                logits,
+                probs,
+                sampling_metadata,
+                max_model_len,
+                share_inputs,
+                accept_all_drafts,
+                reject_all_drafts,
+            )


class MTPSampler(nn.Layer):
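
The diff splits the old monolithic `forward_xpu` into a NAIVE sampling path (`_normal_sample_xpu`) and a draft-verification path (`_verify_and_sample_xpu`), replacing the `speculate_verify` op with `verify_draft_tokens`, which dispatches on `self.verify_strategy`. For intuition about what the GREEDY strategy checks, here is a minimal pure-Python sketch of greedy draft verification; `greedy_verify` and the NumPy setting are illustrative stand-ins, not the fused XPU kernel's actual interface, and batching, stop flags, and the in-place `accept_tokens` buffers are all ignored:

```python
import numpy as np

def greedy_verify(draft_tokens: np.ndarray, target_tokens: np.ndarray) -> np.ndarray:
    """Accept draft tokens while they match the target model's argmax;
    on the first mismatch, substitute the target token and stop."""
    accepted = []
    for i, draft in enumerate(draft_tokens):
        if draft == target_tokens[i]:
            accepted.append(int(draft))             # draft agrees with the target model
        else:
            accepted.append(int(target_tokens[i]))  # mismatch: take the target token
            return np.array(accepted)
    # Every draft was accepted, so the target logits yield one extra "bonus" token.
    accepted.append(int(target_tokens[len(draft_tokens)]))
    return np.array(accepted)

# 3 draft tokens; target predictions for positions 0..3 (index 3 is the bonus slot).
print(greedy_verify(np.array([5, 7, 2]), np.array([5, 7, 9, 4])))  # -> [5 7 9]
```

The TARGET_MATCH strategy follows the same comparison rule, except that `target_tokens` comes from `top_k_top_p_sampling` over the target probabilities rather than from a plain argmax.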
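The TOPP strategy takes a looser acceptance rule: a draft token passes if it falls inside the target model's top-p candidate set at that position, which is what `top_p_candidates` precomputes (`candidate_ids`/`candidate_scores`/`candidate_lens`). A rough sketch of that rule, again with hypothetical helper names and NumPy standing in for the fused kernel:

```python
import numpy as np

def top_p_set(probs: np.ndarray, top_p: float) -> np.ndarray:
    """Smallest set of token ids (most probable first) whose mass reaches top_p."""
    order = np.argsort(-probs)
    cutoff = int(np.searchsorted(np.cumsum(probs[order]), top_p)) + 1
    return order[:cutoff]

def topp_verify(draft_tokens, step_probs, top_p):
    """Accept each draft token while it lies inside the target's top-p set;
    on the first miss, fall back to the target's argmax and stop."""
    accepted = []
    for draft, probs in zip(draft_tokens, step_probs):
        candidates = top_p_set(probs, top_p)
        if draft in candidates:
            accepted.append(int(draft))
        else:
            accepted.append(int(candidates[0]))  # candidates[0] is the argmax
            break
    return accepted

rng = np.random.default_rng(0)
step_probs = rng.dirichlet(np.ones(8), size=3)  # target probs for 3 draft positions
print(topp_verify([1, 4, 2], step_probs, top_p=0.8))
```

Under this rule a larger `top_p` accepts more drafts per step (higher speedup, looser fidelity to the target distribution), which is presumably why `sampling_metadata.top_p` is passed through to `verify_draft_tokens` alongside the candidate tensors.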