Skip to content

Commit 56b761d

Browse files
authored
[Cherry-Pick][Speculative Decoding][BugFix] Fix apply repeat times penalty kernel and change spec default verify strategy (#7467) (#7468)
* Fix the `update_repeat_times` kernel (iterate only over `cur_len[bi]` generated tokens instead of the full `length_id`) and change the default speculative verify strategy from "topp" to "target_match"
* Fix the unit tests to pin `verify_strategy="topp"` explicitly where the tests depend on it
1 parent 650d1e4 commit 56b761d

4 files changed

Lines changed: 81 additions & 83 deletions

File tree

custom_ops/gpu_ops/speculate_decoding/speculate_get_token_penalty_multi_scores.cu

Lines changed: 75 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,12 @@
1616

1717
template <typename T>
1818
__global__ inline void min_length_logits_process(
19-
T *logits,
20-
const int64_t *cur_len,
21-
const int64_t *min_len,
22-
const int64_t *eos_token_id,
23-
const int *batch_id_per_token_output,
24-
const int *cu_seqlens_q_output,
19+
T* logits,
20+
const int64_t* cur_len,
21+
const int64_t* min_len,
22+
const int64_t* eos_token_id,
23+
const int* batch_id_per_token_output,
24+
const int* cu_seqlens_q_output,
2525
const int64_t token_num,
2626
const int64_t bs,
2727
const int64_t length,
@@ -46,12 +46,12 @@ __global__ inline void min_length_logits_process(
4646

4747
template <>
4848
__global__ inline void min_length_logits_process<half>(
49-
half *logits,
50-
const int64_t *cur_len,
51-
const int64_t *min_len,
52-
const int64_t *eos_token_id,
53-
const int *batch_id_per_token_output,
54-
const int *cu_seqlens_q_output,
49+
half* logits,
50+
const int64_t* cur_len,
51+
const int64_t* min_len,
52+
const int64_t* eos_token_id,
53+
const int* batch_id_per_token_output,
54+
const int* cu_seqlens_q_output,
5555
const int64_t token_num,
5656
const int64_t bs,
5757
const int64_t length,
@@ -74,11 +74,11 @@ __global__ inline void min_length_logits_process<half>(
7474
}
7575
}
7676

77-
__global__ void update_repeat_times(const int64_t *token_ids_all,
78-
const int64_t *prompt_lens,
79-
const int64_t *cur_len,
80-
int *repeat_times,
81-
const int *batch_id_per_token_output,
77+
__global__ void update_repeat_times(const int64_t* token_ids_all,
78+
const int64_t* prompt_lens,
79+
const int64_t* cur_len,
80+
int* repeat_times,
81+
const int* batch_id_per_token_output,
8282
const int64_t token_num,
8383
const int64_t bs,
8484
const int64_t length,
@@ -93,9 +93,9 @@ __global__ void update_repeat_times(const int64_t *token_ids_all,
9393
return;
9494
}
9595
int tid = threadIdx.x;
96-
const int64_t *pre_ids_now = token_ids_all + bi * length_id + prompt_lens[bi];
97-
int *repeat_times_now = repeat_times + token_idx * length;
98-
for (int i = tid; i < length_id; i += blockDim.x) {
96+
const int64_t* pre_ids_now = token_ids_all + bi * length_id + prompt_lens[bi];
97+
int* repeat_times_now = repeat_times + token_idx * length;
98+
for (int i = tid; i < cur_len[bi]; i += blockDim.x) {
9999
int64_t id = pre_ids_now[i];
100100
if (id < 0) break;
101101
atomicAdd(&repeat_times_now[id], 1);
@@ -104,13 +104,13 @@ __global__ void update_repeat_times(const int64_t *token_ids_all,
104104

105105
template <typename T>
106106
__global__ void update_value_by_repeat_times(
107-
const int *repeat_times,
108-
const T *penalty_scores,
109-
const T *frequency_score,
110-
const T *presence_score,
111-
const float *temperatures,
112-
T *logits,
113-
const int *batch_id_per_token_output,
107+
const int* repeat_times,
108+
const T* penalty_scores,
109+
const T* frequency_score,
110+
const T* presence_score,
111+
const float* temperatures,
112+
T* logits,
113+
const int* batch_id_per_token_output,
114114
const int64_t token_num,
115115
const int64_t bs,
116116
const int64_t length,
@@ -121,8 +121,8 @@ __global__ void update_value_by_repeat_times(
121121
if (bi < 0) return;
122122
if (bi >= bs) return;
123123
int tid = threadIdx.x;
124-
T *logits_now = logits + token_idx * length;
125-
const int *repeat_times_now = repeat_times + token_idx * length;
124+
T* logits_now = logits + token_idx * length;
125+
const int* repeat_times_now = repeat_times + token_idx * length;
126126
float alpha = static_cast<float>(penalty_scores[bi]);
127127
float beta = static_cast<float>(frequency_score[bi]);
128128
float gamma = static_cast<float>(presence_score[bi]);
@@ -138,10 +138,10 @@ __global__ void update_value_by_repeat_times(
138138
}
139139

140140
template <typename T>
141-
__global__ void ban_bad_words(T *logits,
142-
const int64_t *bad_tokens,
143-
const int64_t *bad_tokens_len,
144-
const int *batch_id_per_token_output,
141+
__global__ void ban_bad_words(T* logits,
142+
const int64_t* bad_tokens,
143+
const int64_t* bad_tokens_len,
144+
const int* batch_id_per_token_output,
145145
const int64_t token_num,
146146
const int64_t bs,
147147
const int64_t length,
@@ -153,8 +153,8 @@ __global__ void ban_bad_words(T *logits,
153153
if (bi < 0) return;
154154
if (bi >= bs) return;
155155
int tid = threadIdx.x;
156-
T *logits_now = logits + token_idx * length;
157-
const int64_t *bad_tokens_now = bad_tokens + bi * bad_words_length;
156+
T* logits_now = logits + token_idx * length;
157+
const int64_t* bad_tokens_now = bad_tokens + bi * bad_words_length;
158158
const int32_t bad_token_len =
159159
static_cast<int32_t>(min(bad_tokens_len[bi], bad_words_length));
160160
for (int i = tid; i < bad_token_len; i += blockDim.x) {
@@ -166,21 +166,21 @@ __global__ void ban_bad_words(T *logits,
166166

167167
template <paddle::DataType D>
168168
void token_penalty_multi_scores_kernel(
169-
const paddle::Tensor &token_ids_all,
170-
const paddle::Tensor &prompt_lens,
171-
const paddle::Tensor &logits,
172-
const paddle::Tensor &penalty_scores,
173-
const paddle::Tensor &frequency_score,
174-
const paddle::Tensor &presence_score,
175-
const paddle::Tensor &temperatures,
176-
const paddle::Tensor &bad_tokens,
177-
const paddle::Tensor &bad_tokens_len,
178-
const paddle::Tensor &cur_len,
179-
const paddle::Tensor &min_len,
180-
const paddle::Tensor &eos_token_id,
181-
const paddle::Tensor &seq_lens_this_time,
182-
const paddle::Tensor &batch_id_per_token_output,
183-
const paddle::Tensor &cu_seqlens_q_output,
169+
const paddle::Tensor& token_ids_all,
170+
const paddle::Tensor& prompt_lens,
171+
const paddle::Tensor& logits,
172+
const paddle::Tensor& penalty_scores,
173+
const paddle::Tensor& frequency_score,
174+
const paddle::Tensor& presence_score,
175+
const paddle::Tensor& temperatures,
176+
const paddle::Tensor& bad_tokens,
177+
const paddle::Tensor& bad_tokens_len,
178+
const paddle::Tensor& cur_len,
179+
const paddle::Tensor& min_len,
180+
const paddle::Tensor& eos_token_id,
181+
const paddle::Tensor& seq_lens_this_time,
182+
const paddle::Tensor& batch_id_per_token_output,
183+
const paddle::Tensor& cu_seqlens_q_output,
184184
const int max_seq_len) {
185185
typedef PDTraits<D> traits_;
186186
typedef typename traits_::DataType DataType_;
@@ -198,8 +198,7 @@ void token_penalty_multi_scores_kernel(
198198
int64_t end_length = eos_token_id.shape()[0];
199199
int block_size = (token_num + 32 - 1) / 32 * 32;
200200
min_length_logits_process<<<1, block_size, 0, cu_stream>>>(
201-
reinterpret_cast<DataType_ *>(
202-
const_cast<data_t *>(logits.data<data_t>())),
201+
reinterpret_cast<DataType_*>(const_cast<data_t*>(logits.data<data_t>())),
203202
cur_len.data<int64_t>(),
204203
min_len.data<int64_t>(),
205204
eos_token_id.data<int64_t>(),
@@ -230,15 +229,15 @@ void token_penalty_multi_scores_kernel(
230229
update_value_by_repeat_times<DataType_>
231230
<<<token_num, block_size, 0, cu_stream>>>(
232231
repeat_times.data<int>(),
233-
reinterpret_cast<DataType_ *>(
234-
const_cast<data_t *>(penalty_scores.data<data_t>())),
235-
reinterpret_cast<DataType_ *>(
236-
const_cast<data_t *>(frequency_score.data<data_t>())),
237-
reinterpret_cast<DataType_ *>(
238-
const_cast<data_t *>(presence_score.data<data_t>())),
232+
reinterpret_cast<DataType_*>(
233+
const_cast<data_t*>(penalty_scores.data<data_t>())),
234+
reinterpret_cast<DataType_*>(
235+
const_cast<data_t*>(frequency_score.data<data_t>())),
236+
reinterpret_cast<DataType_*>(
237+
const_cast<data_t*>(presence_score.data<data_t>())),
239238
temperatures.data<float>(),
240-
reinterpret_cast<DataType_ *>(
241-
const_cast<data_t *>(logits.data<data_t>())),
239+
reinterpret_cast<DataType_*>(
240+
const_cast<data_t*>(logits.data<data_t>())),
242241
batch_id_per_token_output.data<int>(),
243242
token_num,
244243
bs,
@@ -247,8 +246,7 @@ void token_penalty_multi_scores_kernel(
247246
block_size = (length_bad_words + 32 - 1) / 32 * 32;
248247
block_size = min(block_size, 512);
249248
ban_bad_words<DataType_><<<token_num, block_size, 0, cu_stream>>>(
250-
reinterpret_cast<DataType_ *>(
251-
const_cast<data_t *>(logits.data<data_t>())),
249+
reinterpret_cast<DataType_*>(const_cast<data_t*>(logits.data<data_t>())),
252250
bad_tokens.data<int64_t>(),
253251
bad_tokens_len.data<int64_t>(),
254252
batch_id_per_token_output.data<int>(),
@@ -260,21 +258,21 @@ void token_penalty_multi_scores_kernel(
260258
}
261259

262260
void SpecTokenPenaltyMultiScores(
263-
const paddle::Tensor &token_ids_all,
264-
const paddle::Tensor &prompt_lens,
265-
const paddle::Tensor &logits,
266-
const paddle::Tensor &penalty_scores,
267-
const paddle::Tensor &frequency_scores,
268-
const paddle::Tensor &presence_scores,
269-
const paddle::Tensor &temperatures,
270-
const paddle::Tensor &bad_tokens,
271-
const paddle::Tensor &bad_tokens_len,
272-
const paddle::Tensor &cur_len,
273-
const paddle::Tensor &min_len,
274-
const paddle::Tensor &eos_token_id,
275-
const paddle::Tensor &seq_lens_this_time,
276-
const paddle::Tensor &batch_id_per_token_output,
277-
const paddle::Tensor &cu_seqlens_q_output,
261+
const paddle::Tensor& token_ids_all,
262+
const paddle::Tensor& prompt_lens,
263+
const paddle::Tensor& logits,
264+
const paddle::Tensor& penalty_scores,
265+
const paddle::Tensor& frequency_scores,
266+
const paddle::Tensor& presence_scores,
267+
const paddle::Tensor& temperatures,
268+
const paddle::Tensor& bad_tokens,
269+
const paddle::Tensor& bad_tokens_len,
270+
const paddle::Tensor& cur_len,
271+
const paddle::Tensor& min_len,
272+
const paddle::Tensor& eos_token_id,
273+
const paddle::Tensor& seq_lens_this_time,
274+
const paddle::Tensor& batch_id_per_token_output,
275+
const paddle::Tensor& cu_seqlens_q_output,
278276
const int max_seq_len) {
279277
switch (logits.type()) {
280278
case paddle::DataType::BFLOAT16: {

fastdeploy/config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -774,7 +774,7 @@ class SpeculativeConfig:
774774
"benchmark_mode": False,
775775
"enf_gen_phase_tag": False,
776776
"enable_draft_logprob": False,
777-
"verify_strategy": "topp",
777+
"verify_strategy": "target_match",
778778
"accept_policy": "normal",
779779
}
780780

tests/layers/test_speculative_sampler.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -97,12 +97,12 @@ def _create_default_sampling_metadata(
9797
return fake_sampling_metadata
9898

9999

100-
def _create_fd_config(max_model_len, method=None):
100+
def _create_fd_config(max_model_len, method=None, verify_strategy="topp"):
101101
model_config: Mock = Mock()
102102
model_config.max_model_len = max_model_len
103103
model_config.architectures = ["test_model"]
104104
model_config.mm_max_tokens_per_item = None
105-
speculative_config = SpeculativeConfig({"method": method} if method else {})
105+
speculative_config = SpeculativeConfig({"method": method, "verify_strategy": verify_strategy})
106106
graph_opt_config = GraphOptimizationConfig({})
107107
scheduler_config = SchedulerConfig({})
108108
parallel_config = ParallelConfig({})
@@ -187,7 +187,7 @@ def test_speculative_sampler():
187187
max_draft_token_num = 1
188188

189189
# Use ngram method for speculative decoding
190-
fd_config = _create_fd_config(max_model_len, method="ngram")
190+
fd_config = _create_fd_config(max_model_len, method="ngram", verify_strategy="topp")
191191
sampling_metadata = _create_default_sampling_metadata(batch_size, min_seq_len, max_seq_len)
192192
logits = _create_fake_logits(batch_size * (max_draft_token_num + 1), vocab_size)
193193
share_inputs = _create_share_inputs(batch_size, max_draft_token_num, max_model_len, vocab_size)
@@ -208,7 +208,7 @@ def test_speculative_sampler_logprobs():
208208
max_draft_token_num = 1
209209

210210
# Use ngram method for speculative decoding
211-
fd_config = _create_fd_config(max_model_len, method="ngram")
211+
fd_config = _create_fd_config(max_model_len, method="ngram", verify_strategy="topp")
212212
share_inputs = _create_share_inputs(batch_size, max_draft_token_num, max_model_len, vocab_size)
213213
sampling_metadata = _create_default_sampling_metadata(batch_size, min_seq_len, max_seq_len, max_num_logprobs=0)
214214
sampling_metadata.share_inputs = share_inputs

tests/operators/test_speculate_get_token_penalty_multi_scores.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ def update_repeat_times(
6161
token_ids_all_now = token_ids_all[bi]
6262
repeat_times_now = repeat_times[token_idx]
6363

64-
for i in range(length_id):
64+
for i in range(cur_len[bi]):
6565
id = token_ids_all_now[i]
6666
if id < 0:
6767
break

0 commit comments

Comments
 (0)