Skip to content

Commit 650d1e4

Browse files
authored
[Cherry-Pick][Speculative Decoding] Add MTP logprob support for PD disaggregation (#7442) (#7464)
* support mtp logprob in pd * fix * fix * fix * fix xpu bugs
1 parent 185708b commit 650d1e4

File tree

7 files changed

+389
-78
lines changed

7 files changed

+389
-78
lines changed
Lines changed: 218 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,218 @@
1+
// Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include <stdio.h>
16+
#include <string.h>
17+
#include <sys/ipc.h>
18+
#include <sys/msg.h>
19+
#include <sys/types.h>
20+
#include "paddle/extension.h"
21+
#include "../../custom_ftok.h"
22+
#include "../speculate_logprob_msg.h"
23+
24+
#ifndef PD_BUILD_STATIC_OP
25+
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
26+
#endif
27+
28+
namespace {

// Resolves the SysV message-queue id from the INFERENCE_MSG_QUEUE_ID
// environment variable; falls back to 1 when the variable is unset.
int GetMsgQueueIdFromEnv() {
  int msg_queue_id = 1;
  if (const char* env_p = std::getenv("INFERENCE_MSG_QUEUE_ID")) {
    msg_queue_id = std::stoi(std::string(env_p));
#ifdef SPECULATE_SAVE_WITH_OUTPUT_DEBUG
    std::cout << "Your INFERENCE_MSG_QUEUE_ID is: " << msg_queue_id
              << std::endl;
#endif
  } else {
#ifdef SPECULATE_SAVE_WITH_OUTPUT_DEBUG
    std::cout << "Failed to get INFERENCE_MSG_QUEUE_ID at env, use default."
              << std::endl;
#endif
  }
  return msg_queue_id;
}

// Resolves and validates the message id from INFERENCE_MSG_ID (default 1).
// The sign of this id carries the stop flag on the wire, so negative values
// are rejected, and +/-2 is reserved as the no-output sentinel.
int GetInferenceMsgIdFromEnv() {
  int inference_msg_id = 1;
  if (const char* env_p = std::getenv("INFERENCE_MSG_ID")) {
    inference_msg_id = std::stoi(std::string(env_p));
    if (inference_msg_id == 2) {
      // 2 and -2 are reserved for the no-output indication.
      throw std::runtime_error(
          " INFERENCE_MSG_ID cannot be 2, please use other number.");
    }
    if (inference_msg_id < 0) {
      throw std::runtime_error(
          " INFERENCE_MSG_ID cannot be negative, please use other "
          "number.");
    }
#ifdef SPECULATE_SAVE_WITH_OUTPUT_DEBUG
    std::cout << "Your INFERENCE_MSG_ID is: " << inference_msg_id << std::endl;
#endif
  } else {
#ifdef SPECULATE_SAVE_WITH_OUTPUT_DEBUG
    std::cout << "Failed to get INFERENCE_MSG_ID at env, use (int)1 as default."
              << std::endl;
#endif
  }
  return inference_msg_id;
}

}  // namespace

// Packs, for every batch slot, the sampled first token together with its
// top-K logprob candidates/scores/rank plus the remaining draft tokens, and
// publishes the result on a SysV message queue so the consumer side of PD
// disaggregation (see SpeculateGetOutMmsgTopK) can read it back.
//
// message_flag:   3 = target model, 4 = draft (MTP) model.
// rank_id /
// save_each_rank: unless save_each_rank is set, only rank 0 sends.
//
// Per-slot token count sent in meta[3 + i]:
//   0  -> chunked prefill still running, or a stopped slot;
//   -9 -> slot was preempted;
//   otherwise token_num_per_batch[i] + 1 (sampled token + drafts).
void MTPSaveFirstTokenWithTopK(const paddle::Tensor& sampled_token_ids,
                               const paddle::Tensor& logprob_token_ids,
                               const paddle::Tensor& logprob_scores,
                               const paddle::Tensor& logprob_ranks,
                               const paddle::Tensor& token_num_per_batch,
                               const paddle::Tensor& cu_batch_token_offset,
                               const paddle::Tensor& not_need_stop,
                               const paddle::Tensor& seq_lens_decoder,
                               const paddle::Tensor& prompt_lens,
                               const paddle::Tensor& preempted_idx,
                               int message_flag,  // Target: 3, Draft: 4
                               int64_t rank_id,
                               bool save_each_rank) {
  if (!save_each_rank && rank_id > 0) {
    return;
  }

  const int max_draft_tokens = sampled_token_ids.shape()[1];
  const int bsz = token_num_per_batch.shape()[0];

  // Stage every input on the host before packing the message.
  auto sampled_token_ids_cpu =
      sampled_token_ids.copy_to(paddle::CPUPlace(), false);
  auto logprob_token_ids_cpu =
      logprob_token_ids.copy_to(paddle::CPUPlace(), false);
  auto logprob_scores_cpu = logprob_scores.copy_to(paddle::CPUPlace(), false);
  auto logprob_ranks_cpu = logprob_ranks.copy_to(paddle::CPUPlace(), false);
  auto token_num_per_batch_cpu =
      token_num_per_batch.copy_to(paddle::CPUPlace(), false);
  auto cu_batch_token_offset_cpu =
      cu_batch_token_offset.copy_to(paddle::CPUPlace(), false);
  auto seq_lens_decoder_cpu =
      seq_lens_decoder.copy_to(paddle::CPUPlace(), true);
  auto prompt_lens_cpu = prompt_lens.copy_to(paddle::CPUPlace(), true);
  // Copied to host like the other inputs so the raw pointer below is safe to
  // dereference on CPU (a GPU/XPU-resident tensor must not be read in place).
  auto preempted_idx_cpu = preempted_idx.copy_to(paddle::CPUPlace(), true);

  int64_t* sampled_token_ids_data = sampled_token_ids_cpu.data<int64_t>();
  int64_t* logprob_token_ids_data = logprob_token_ids_cpu.data<int64_t>();
  float* logprob_scores_data = logprob_scores_cpu.data<float>();
  int64_t* logprob_ranks_data = logprob_ranks_cpu.data<int64_t>();
  int* token_num_per_batch_data = token_num_per_batch_cpu.data<int>();
  int* cu_batch_token_offset_data = cu_batch_token_offset_cpu.data<int>();
  int* seq_lens_decoder_data = seq_lens_decoder_cpu.data<int>();
  int64_t* prompt_lens_data = prompt_lens_cpu.data<int64_t>();
  const int32_t* preempted_idx_data = preempted_idx_cpu.data<int32_t>();

  // Static: the message is too large for the stack. NOTE(review): this makes
  // the op non-reentrant — confirm it is only invoked from a single thread.
  static struct msgdata msg_sed;

  // Guard the fixed-capacity layout: meta holds 3 header ints plus one token
  // count per batch slot.
  const int max_bsz =
      static_cast<int>(sizeof(msg_sed.meta) / sizeof(msg_sed.meta[0])) - 3;
  if (bsz > max_bsz) {
    throw std::runtime_error(
        "MTPSaveFirstTokenWithTopK: batch size exceeds message capacity.");
  }

  const int msg_queue_id = GetMsgQueueIdFromEnv();
  const int inference_msg_id_from_env = GetInferenceMsgIdFromEnv();

  // First call wins: key and queue id are resolved once per process.
  static key_t key = custom_ftok("/dev/shm", msg_queue_id);
  static int msgid = msgget(key, IPC_CREAT | 0666);
#ifdef SPECULATE_SAVE_WITH_OUTPUT_DEBUG
  std::cout << "save_output_key: " << key << std::endl;
  std::cout << "save msgid: " << msgid << std::endl;
#endif

  msg_sed.mtype = 1;
  // meta[0] carries the msg id; a negative sign tells the consumer to stop.
  msg_sed.meta[0] = not_need_stop.data<bool>()[0] ? inference_msg_id_from_env
                                                  : -inference_msg_id_from_env;
  msg_sed.meta[1] = message_flag;
  msg_sed.meta[2] = bsz;

  const int max_num_logprobs = logprob_token_ids.shape()[1];
  // Per-slot token capacity of the wire format (MAX_DRAFT_TOKEN_NUM).
  const int max_tokens_per_slot = static_cast<int>(
      sizeof(msg_sed.mtext[0].ranks) / sizeof(msg_sed.mtext[0].ranks[0]));
  for (int i = 0; i < bsz; i++) {
    int cur_token_num;
    if (seq_lens_decoder_data[i] < prompt_lens_data[i] ||
        token_num_per_batch_data[i] == 0) {
      // chunk prefill or stop slots
      cur_token_num = 0;
    } else {
      cur_token_num = token_num_per_batch_data[i] + 1;
      if (cur_token_num > max_tokens_per_slot) {
        throw std::runtime_error(
            "MTPSaveFirstTokenWithTopK: token num exceeds message capacity.");
      }
    }
    msg_sed.meta[3 + i] = cur_token_num;
    if (preempted_idx_data[i] == 1) {
      // -9 marks a preempted slot for the consumer.
      msg_sed.meta[3 + i] = -9;
    }

    auto* cur_batch_msg_sed = &msg_sed.mtext[i];
    const int token_offset = cu_batch_token_offset_data[i];
    for (int j = 0; j < cur_token_num; j++) {
      auto* cur_tokens = &cur_batch_msg_sed->tokens[j * (SPEC_LOGPROB_K + 1)];
      auto* cur_scores = &cur_batch_msg_sed->scores[j * (SPEC_LOGPROB_K + 1)];
      if (j == 0) {
        // first token has full logprobs
        for (int k = 0; k < SPEC_LOGPROB_K + 1; k++) {
          if (k == 0) {
            // Slot 0 holds the sampled token itself with its score.
            cur_tokens[k] =
                (int)sampled_token_ids_data[i * max_draft_tokens + j];
            cur_scores[k] =
                logprob_scores_data[(token_offset + j) * (SPEC_LOGPROB_K + 1) +
                                    k];
          } else if (k < max_num_logprobs) {
            // only for first token
            cur_tokens[k] =
                (int)logprob_token_ids_data[(token_offset + j) *
                                                (SPEC_LOGPROB_K + 1) +
                                            k];
            cur_scores[k] =
                logprob_scores_data[(token_offset + j) * (SPEC_LOGPROB_K + 1) +
                                    k];
          } else {
            // Pad unused top-K slots.
            cur_tokens[k] = -1;
            cur_scores[k] = 0.0;
          }
        }
        cur_batch_msg_sed->ranks[j] = (int)logprob_ranks_data[token_offset + j];
      } else {
        // draft token only has token_id
        cur_tokens[0] = (int)sampled_token_ids_data[i * max_draft_tokens + j];
      }
    }
  }
#ifdef SPECULATE_SAVE_WITH_OUTPUT_DEBUG
  std::cout << "msg data: " << std::endl;
  std::cout << "stop_flag: " << msg_sed.meta[0]
            << ", message_flag: " << msg_sed.meta[1]
            << ", bsz: " << msg_sed.meta[2] << std::endl;
  for (int i = 0; i < bsz; i++) {
    int cur_token_num = msg_sed.meta[3 + i];
    auto* cur_batch_msg_sed = &msg_sed.mtext[i];
    std::cout << "batch " << i << " token_num: " << cur_token_num << std::endl;
    for (int j = 0; j < cur_token_num; j++) {
      auto* cur_tokens = &cur_batch_msg_sed->tokens[j * (SPEC_LOGPROB_K + 1)];
      auto* cur_scores = &cur_batch_msg_sed->scores[j * (SPEC_LOGPROB_K + 1)];
      std::cout << "tokens: ";
      for (int k = 0; k < SPEC_LOGPROB_K + 1; k++) {
        std::cout << cur_tokens[k] << " ";
      }
      std::cout << std::endl;
      std::cout << "scores: ";
      for (int k = 0; k < SPEC_LOGPROB_K + 1; k++) {
        std::cout << cur_scores[k] << " ";
      }
      std::cout << std::endl;
      std::cout << "ranks: " << cur_batch_msg_sed->ranks[j] << std::endl;
    }
  }
  std::cout << std::endl;
#endif
  // NOTE(review): with flag 0 msgsnd blocks on a full queue, so -1 here is a
  // real error (e.g. message size above MSGMAX) and the message is dropped.
  if (msgsnd(msgid, &msg_sed, sizeof(msg_sed) - sizeof(long), 0) == -1) {
    printf("full msg buffer\n");
  }
}
205+
206+
// Registers the custom op with Paddle. Input order must match the
// MTPSaveFirstTokenWithTopK signature; the attrs map to message_flag
// (3 = target model, 4 = draft model), rank_id and save_each_rank.
PD_BUILD_STATIC_OP(mtp_save_first_token_with_topk)
    .Inputs({"sampled_token_ids",
             "logprob_token_ids",
             "logprob_scores",
             "logprob_ranks",
             "token_num_per_batch",
             "cu_batch_token_offset",
             "not_need_stop",
             "seq_lens_decoder",
             "prompt_lens",
             "preempted_idx"})
    .Attrs({"message_flag: int", "rank_id: int64_t", "save_each_rank: bool"})
    .SetKernelFn(PD_KERNEL(MTPSaveFirstTokenWithTopK));

custom_ops/gpu_ops/speculate_decoding/speculate_get_output_with_topk.cc

Lines changed: 16 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -19,27 +19,12 @@
1919
#include <sys/types.h>
2020
#include "paddle/extension.h"
2121
#include "../custom_ftok.h"
22+
#include "speculate_logprob_msg.h"
2223

2324
#ifndef PD_BUILD_STATIC_OP
2425
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
2526
#endif
2627

27-
#define MAX_BSZ 512
28-
#define K 20
29-
#define MAX_DRAFT_TOKEN_NUM 6
30-
31-
struct batch_msgdata {
32-
int tokens[MAX_DRAFT_TOKEN_NUM * (K + 1)];
33-
float scores[MAX_DRAFT_TOKEN_NUM * (K + 1)];
34-
int ranks[MAX_DRAFT_TOKEN_NUM];
35-
};
36-
37-
struct msgdata {
38-
long mtype;
39-
int meta[3 + MAX_BSZ]; // stop_flag, message_flag, bsz, batch_token_nums
40-
batch_msgdata mtext[MAX_BSZ];
41-
};
42-
4328
void SpeculateGetOutMmsgTopK(const paddle::Tensor& output_tokens,
4429
const paddle::Tensor& output_scores,
4530
const paddle::Tensor& output_ranks,
@@ -93,22 +78,22 @@ void SpeculateGetOutMmsgTopK(const paddle::Tensor& output_tokens,
9378
output_tokens_data[1] = (int64_t)msg_rcv.meta[1];
9479
output_tokens_data[2] = (int64_t)msg_rcv.meta[2];
9580

96-
int output_tokens_offset = 3 + MAX_BSZ;
81+
int output_tokens_offset = 3 + SPEC_LOGPROB_MAX_BSZ;
9782
for (int i = 0; i < bsz; i++) {
9883
int cur_token_num = msg_rcv.meta[3 + i];
9984
output_tokens_data[3 + i] = (int64_t)cur_token_num; // batch_token_nums
10085

10186
auto* cur_output_token = output_tokens_data + output_tokens_offset +
102-
i * (MAX_DRAFT_TOKEN_NUM * (K + 1));
87+
i * (MAX_DRAFT_TOKEN_NUM * (SPEC_LOGPROB_K + 1));
10388
auto* cur_output_score =
104-
output_scores_data + i * (MAX_DRAFT_TOKEN_NUM * (K + 1));
89+
output_scores_data + i * (MAX_DRAFT_TOKEN_NUM * (SPEC_LOGPROB_K + 1));
10590
auto* cur_batch_msg_rcv = &msg_rcv.mtext[i];
10691
for (int j = 0; j < cur_token_num; j++) {
10792
for (int k = 0; k < real_k + 1; k++) {
108-
cur_output_token[j * (K + 1) + k] =
109-
(int64_t)cur_batch_msg_rcv->tokens[j * (K + 1) + k];
110-
cur_output_score[j * (K + 1) + k] =
111-
cur_batch_msg_rcv->scores[j * (K + 1) + k];
93+
cur_output_token[j * (SPEC_LOGPROB_K + 1) + k] =
94+
(int64_t)cur_batch_msg_rcv->tokens[j * (SPEC_LOGPROB_K + 1) + k];
95+
cur_output_score[j * (SPEC_LOGPROB_K + 1) + k] =
96+
cur_batch_msg_rcv->scores[j * (SPEC_LOGPROB_K + 1) + k];
11297
}
11398
output_ranks_data[i * MAX_DRAFT_TOKEN_NUM + j] =
11499
(int64_t)cur_batch_msg_rcv->ranks[j];
@@ -124,17 +109,19 @@ void SpeculateGetOutMmsgTopK(const paddle::Tensor& output_tokens,
124109
std::cout << "batch " << i << " token_num: " << cur_token_num << std::endl;
125110
for (int j = 0; j < cur_token_num; j++) {
126111
std::cout << "tokens: ";
127-
for (int k = 0; k < K + 1; k++) {
112+
for (int k = 0; k < SPEC_LOGPROB_K + 1; k++) {
128113
std::cout << output_tokens_data[output_tokens_offset +
129-
i * MAX_DRAFT_TOKEN_NUM * (K + 1) +
130-
j * (K + 1) + k]
114+
i * MAX_DRAFT_TOKEN_NUM *
115+
(SPEC_LOGPROB_K + 1) +
116+
j * (SPEC_LOGPROB_K + 1) + k]
131117
<< " ";
132118
}
133119
std::cout << std::endl;
134120
std::cout << "scores: ";
135-
for (int k = 0; k < K + 1; k++) {
136-
std::cout << output_scores_data[i * MAX_DRAFT_TOKEN_NUM * (K + 1) +
137-
j * (K + 1) + k]
121+
for (int k = 0; k < SPEC_LOGPROB_K + 1; k++) {
122+
std::cout << output_scores_data[i * MAX_DRAFT_TOKEN_NUM *
123+
(SPEC_LOGPROB_K + 1) +
124+
j * (SPEC_LOGPROB_K + 1) + k]
138125
<< " ";
139126
}
140127
std::cout << std::endl;
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
// Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#pragma once
16+
17+
#include <stdio.h>
18+
#include <string.h>
19+
#include <sys/ipc.h>
20+
#include <sys/msg.h>
21+
#include <sys/types.h>
22+
#include "paddle/extension.h"
23+
24+
// Wire-format constants shared by the producer (mtp_save_first_token_with_topk)
// and the consumer (SpeculateGetOutMmsgTopK); both sides must be built with
// identical values.
#define SPEC_LOGPROB_MAX_BSZ 512  // max batch slots per message
#define SPEC_LOGPROB_K 20         // max top-K logprob candidates per token
#define MAX_DRAFT_TOKEN_NUM 6     // max tokens (sampled + drafts) per slot

// Payload for one batch slot. tokens/scores are laid out row-major with a
// stride of SPEC_LOGPROB_K + 1 per token position: slot 0 is the sampled
// token, the remaining slots are top-K candidates (only filled for the first
// token; unused entries are padded with -1 / 0.0 by the sender).
struct batch_msgdata {
  int tokens[MAX_DRAFT_TOKEN_NUM * (SPEC_LOGPROB_K + 1)];
  float scores[MAX_DRAFT_TOKEN_NUM * (SPEC_LOGPROB_K + 1)];
  int ranks[MAX_DRAFT_TOKEN_NUM];
};

// SysV IPC message layout. NOTE(review): sizeof(msgdata) is on the order of
// 0.5 MB, far above the default Linux MSGMAX/MSGMNB limits — confirm the
// deployment raises them, otherwise msgsnd fails.
struct msgdata {
  long mtype;
  // stop_flag, message_flag, bsz, batch_token_nums
  int meta[3 + SPEC_LOGPROB_MAX_BSZ];
  batch_msgdata mtext[SPEC_LOGPROB_MAX_BSZ];
};

0 commit comments

Comments
 (0)