1616
1717template <typename T>
1818__global__ inline void min_length_logits_process (
19- T * logits,
20- const int64_t * cur_len,
21- const int64_t * min_len,
22- const int64_t * eos_token_id,
23- const int * batch_id_per_token_output,
24- const int * cu_seqlens_q_output,
19+ T* logits,
20+ const int64_t * cur_len,
21+ const int64_t * min_len,
22+ const int64_t * eos_token_id,
23+ const int * batch_id_per_token_output,
24+ const int * cu_seqlens_q_output,
2525 const int64_t token_num,
2626 const int64_t bs,
2727 const int64_t length,
@@ -46,12 +46,12 @@ __global__ inline void min_length_logits_process(
4646
4747template <>
4848__global__ inline void min_length_logits_process<half>(
49- half * logits,
50- const int64_t * cur_len,
51- const int64_t * min_len,
52- const int64_t * eos_token_id,
53- const int * batch_id_per_token_output,
54- const int * cu_seqlens_q_output,
49+ half* logits,
50+ const int64_t * cur_len,
51+ const int64_t * min_len,
52+ const int64_t * eos_token_id,
53+ const int * batch_id_per_token_output,
54+ const int * cu_seqlens_q_output,
5555 const int64_t token_num,
5656 const int64_t bs,
5757 const int64_t length,
@@ -74,11 +74,11 @@ __global__ inline void min_length_logits_process<half>(
7474 }
7575}
7676
77- __global__ void update_repeat_times (const int64_t * token_ids_all,
78- const int64_t * prompt_lens,
79- const int64_t * cur_len,
80- int * repeat_times,
81- const int * batch_id_per_token_output,
77+ __global__ void update_repeat_times (const int64_t * token_ids_all,
78+ const int64_t * prompt_lens,
79+ const int64_t * cur_len,
80+ int * repeat_times,
81+ const int * batch_id_per_token_output,
8282 const int64_t token_num,
8383 const int64_t bs,
8484 const int64_t length,
@@ -93,9 +93,9 @@ __global__ void update_repeat_times(const int64_t *token_ids_all,
9393 return ;
9494 }
9595 int tid = threadIdx .x ;
96- const int64_t * pre_ids_now = token_ids_all + bi * length_id + prompt_lens[bi];
97- int * repeat_times_now = repeat_times + token_idx * length;
98- for (int i = tid; i < length_id ; i += blockDim .x ) {
96+ const int64_t * pre_ids_now = token_ids_all + bi * length_id + prompt_lens[bi];
97+ int * repeat_times_now = repeat_times + token_idx * length;
98+ for (int i = tid; i < cur_len[bi] ; i += blockDim .x ) {
9999 int64_t id = pre_ids_now[i];
100100 if (id < 0 ) break ;
101101 atomicAdd (&repeat_times_now[id], 1 );
@@ -104,13 +104,13 @@ __global__ void update_repeat_times(const int64_t *token_ids_all,
104104
105105template <typename T>
106106__global__ void update_value_by_repeat_times (
107- const int * repeat_times,
108- const T * penalty_scores,
109- const T * frequency_score,
110- const T * presence_score,
111- const float * temperatures,
112- T * logits,
113- const int * batch_id_per_token_output,
107+ const int * repeat_times,
108+ const T* penalty_scores,
109+ const T* frequency_score,
110+ const T* presence_score,
111+ const float * temperatures,
112+ T* logits,
113+ const int * batch_id_per_token_output,
114114 const int64_t token_num,
115115 const int64_t bs,
116116 const int64_t length,
@@ -121,8 +121,8 @@ __global__ void update_value_by_repeat_times(
121121 if (bi < 0 ) return ;
122122 if (bi >= bs) return ;
123123 int tid = threadIdx .x ;
124- T * logits_now = logits + token_idx * length;
125- const int * repeat_times_now = repeat_times + token_idx * length;
124+ T* logits_now = logits + token_idx * length;
125+ const int * repeat_times_now = repeat_times + token_idx * length;
126126 float alpha = static_cast <float >(penalty_scores[bi]);
127127 float beta = static_cast <float >(frequency_score[bi]);
128128 float gamma = static_cast <float >(presence_score[bi]);
@@ -138,10 +138,10 @@ __global__ void update_value_by_repeat_times(
138138}
139139
140140template <typename T>
141- __global__ void ban_bad_words (T * logits,
142- const int64_t * bad_tokens,
143- const int64_t * bad_tokens_len,
144- const int * batch_id_per_token_output,
141+ __global__ void ban_bad_words (T* logits,
142+ const int64_t * bad_tokens,
143+ const int64_t * bad_tokens_len,
144+ const int * batch_id_per_token_output,
145145 const int64_t token_num,
146146 const int64_t bs,
147147 const int64_t length,
@@ -153,8 +153,8 @@ __global__ void ban_bad_words(T *logits,
153153 if (bi < 0 ) return ;
154154 if (bi >= bs) return ;
155155 int tid = threadIdx .x ;
156- T * logits_now = logits + token_idx * length;
157- const int64_t * bad_tokens_now = bad_tokens + bi * bad_words_length;
156+ T* logits_now = logits + token_idx * length;
157+ const int64_t * bad_tokens_now = bad_tokens + bi * bad_words_length;
158158 const int32_t bad_token_len =
159159 static_cast <int32_t >(min (bad_tokens_len[bi], bad_words_length));
160160 for (int i = tid; i < bad_token_len; i += blockDim .x ) {
@@ -166,21 +166,21 @@ __global__ void ban_bad_words(T *logits,
166166
167167template <paddle::DataType D>
168168void token_penalty_multi_scores_kernel (
169- const paddle::Tensor & token_ids_all,
170- const paddle::Tensor & prompt_lens,
171- const paddle::Tensor & logits,
172- const paddle::Tensor & penalty_scores,
173- const paddle::Tensor & frequency_score,
174- const paddle::Tensor & presence_score,
175- const paddle::Tensor & temperatures,
176- const paddle::Tensor & bad_tokens,
177- const paddle::Tensor & bad_tokens_len,
178- const paddle::Tensor & cur_len,
179- const paddle::Tensor & min_len,
180- const paddle::Tensor & eos_token_id,
181- const paddle::Tensor & seq_lens_this_time,
182- const paddle::Tensor & batch_id_per_token_output,
183- const paddle::Tensor & cu_seqlens_q_output,
169+ const paddle::Tensor& token_ids_all,
170+ const paddle::Tensor& prompt_lens,
171+ const paddle::Tensor& logits,
172+ const paddle::Tensor& penalty_scores,
173+ const paddle::Tensor& frequency_score,
174+ const paddle::Tensor& presence_score,
175+ const paddle::Tensor& temperatures,
176+ const paddle::Tensor& bad_tokens,
177+ const paddle::Tensor& bad_tokens_len,
178+ const paddle::Tensor& cur_len,
179+ const paddle::Tensor& min_len,
180+ const paddle::Tensor& eos_token_id,
181+ const paddle::Tensor& seq_lens_this_time,
182+ const paddle::Tensor& batch_id_per_token_output,
183+ const paddle::Tensor& cu_seqlens_q_output,
184184 const int max_seq_len) {
185185 typedef PDTraits<D> traits_;
186186 typedef typename traits_::DataType DataType_;
@@ -198,8 +198,7 @@ void token_penalty_multi_scores_kernel(
198198 int64_t end_length = eos_token_id.shape ()[0 ];
199199 int block_size = (token_num + 32 - 1 ) / 32 * 32 ;
200200 min_length_logits_process<<<1 , block_size, 0 , cu_stream>>> (
201- reinterpret_cast <DataType_ *>(
202- const_cast <data_t *>(logits.data <data_t >())),
201+ reinterpret_cast <DataType_*>(const_cast <data_t *>(logits.data <data_t >())),
203202 cur_len.data <int64_t >(),
204203 min_len.data <int64_t >(),
205204 eos_token_id.data <int64_t >(),
@@ -230,15 +229,15 @@ void token_penalty_multi_scores_kernel(
230229 update_value_by_repeat_times<DataType_>
231230 <<<token_num, block_size, 0 , cu_stream>>> (
232231 repeat_times.data <int >(),
233- reinterpret_cast <DataType_ *>(
234- const_cast <data_t *>(penalty_scores.data <data_t >())),
235- reinterpret_cast <DataType_ *>(
236- const_cast <data_t *>(frequency_score.data <data_t >())),
237- reinterpret_cast <DataType_ *>(
238- const_cast <data_t *>(presence_score.data <data_t >())),
232+ reinterpret_cast <DataType_*>(
233+ const_cast <data_t *>(penalty_scores.data <data_t >())),
234+ reinterpret_cast <DataType_*>(
235+ const_cast <data_t *>(frequency_score.data <data_t >())),
236+ reinterpret_cast <DataType_*>(
237+ const_cast <data_t *>(presence_score.data <data_t >())),
239238 temperatures.data <float >(),
240- reinterpret_cast <DataType_ *>(
241- const_cast <data_t *>(logits.data <data_t >())),
239+ reinterpret_cast <DataType_*>(
240+ const_cast <data_t *>(logits.data <data_t >())),
242241 batch_id_per_token_output.data <int >(),
243242 token_num,
244243 bs,
@@ -247,8 +246,7 @@ void token_penalty_multi_scores_kernel(
247246 block_size = (length_bad_words + 32 - 1 ) / 32 * 32 ;
248247 block_size = min (block_size, 512 );
249248 ban_bad_words<DataType_><<<token_num, block_size, 0 , cu_stream>>> (
250- reinterpret_cast <DataType_ *>(
251- const_cast <data_t *>(logits.data <data_t >())),
249+ reinterpret_cast <DataType_*>(const_cast <data_t *>(logits.data <data_t >())),
252250 bad_tokens.data <int64_t >(),
253251 bad_tokens_len.data <int64_t >(),
254252 batch_id_per_token_output.data <int >(),
@@ -260,21 +258,21 @@ void token_penalty_multi_scores_kernel(
260258}
261259
262260void SpecTokenPenaltyMultiScores (
263- const paddle::Tensor & token_ids_all,
264- const paddle::Tensor & prompt_lens,
265- const paddle::Tensor & logits,
266- const paddle::Tensor & penalty_scores,
267- const paddle::Tensor & frequency_scores,
268- const paddle::Tensor & presence_scores,
269- const paddle::Tensor & temperatures,
270- const paddle::Tensor & bad_tokens,
271- const paddle::Tensor & bad_tokens_len,
272- const paddle::Tensor & cur_len,
273- const paddle::Tensor & min_len,
274- const paddle::Tensor & eos_token_id,
275- const paddle::Tensor & seq_lens_this_time,
276- const paddle::Tensor & batch_id_per_token_output,
277- const paddle::Tensor & cu_seqlens_q_output,
261+ const paddle::Tensor& token_ids_all,
262+ const paddle::Tensor& prompt_lens,
263+ const paddle::Tensor& logits,
264+ const paddle::Tensor& penalty_scores,
265+ const paddle::Tensor& frequency_scores,
266+ const paddle::Tensor& presence_scores,
267+ const paddle::Tensor& temperatures,
268+ const paddle::Tensor& bad_tokens,
269+ const paddle::Tensor& bad_tokens_len,
270+ const paddle::Tensor& cur_len,
271+ const paddle::Tensor& min_len,
272+ const paddle::Tensor& eos_token_id,
273+ const paddle::Tensor& seq_lens_this_time,
274+ const paddle::Tensor& batch_id_per_token_output,
275+ const paddle::Tensor& cu_seqlens_q_output,
278276 const int max_seq_len) {
279277 switch (logits.type ()) {
280278 case paddle::DataType::BFLOAT16: {
0 commit comments