 from fastdeploy.model_executor.layers.lm_head import ParallelLMHead
 from fastdeploy.model_executor.layers.moe.moe import FusedMoE
 from fastdeploy.model_executor.layers.normalization import LayerNorm, RMSNorm
+from fastdeploy.model_executor.layers.quantization.fp8_utils import (
+    per_token_group_quant_fp8,
+)
 from fastdeploy.model_executor.layers.rotary_embedding import (
     DeepseekScalingRotaryEmbedding,
 )
 )
 from fastdeploy.platforms import current_platform
 
-if current_platform.is_cuda() or current_platform.is_maca():
-    from fastdeploy.model_executor.ops.gpu import (
-        get_position_ids_and_mask_encoder_batch,
-    )
-
-from fastdeploy.model_executor.layers.quantization.fp8_utils import (
-    per_token_group_quant_fp8,
-)
-from fastdeploy.platforms import current_platform
-
 if current_platform.is_cuda():
     from fastdeploy.model_executor.ops.gpu import (
         cp_gather_indexer_k_quant_cache,
@@ -471,33 +464,6 @@ def forward(
         return output
 
 
-def compute_slot_mapping(
-    block_tables: paddle.Tensor,  # [num_reqs, max_blocks_per_req]
-    positions: paddle.Tensor,  # [num_tokens] position of each token
-    batch_id_per_token: paddle.Tensor,  # [num_tokens] which request each token belongs to
-    block_size: int,
-) -> paddle.Tensor:
-    """
-    Compute slot_mapping.
-
-    Formula: slot = block_id * block_size + offset_in_block
-    """
-    # 1. Compute the block index each token falls into
-    block_idx = positions // block_size  # [num_tokens]
-
-    # 2. Look up the block_id in block_tables
-    #    block_tables[batch_id_per_token, block_idx]
-    block_ids = block_tables[batch_id_per_token, block_idx]  # [num_tokens]
-
-    # 3. Compute the offset within the block
-    block_offset = positions % block_size  # [num_tokens]
-
-    # 4. Compute the slot_mapping
-    slot_mapping = block_ids * block_size + block_offset
-
-    return slot_mapping.cast(paddle.int64)
-
-
 import triton
 import triton.language as tl
 
@@ -686,17 +652,12 @@ def forward(
         weights = weights.unsqueeze(-1) * q_scale * self.softmax_scale * self.index_n_heads**-0.5
         weights = weights.squeeze(-1)
 
-        slot_mapping = compute_slot_mapping(
-            forward_meta.block_tables,
-            forward_meta.position_ids,
-            forward_meta.batch_id_per_token,
-            64,
-        )
-
         indexer_top_k = paddle.full([q_fp8.shape[0], self.index_topk], -1, dtype="int32")
 
         # indexer write_cache
-        indexer_k_quant_and_cache(k, self.indexer_cache, slot_mapping, self.quant_block_size, self.scale_fmt)
+        indexer_k_quant_and_cache(
+            k, self.indexer_cache, forward_meta.slot_mapping, self.quant_block_size, self.scale_fmt
+        )
 
         from fastdeploy.model_executor.layers.quantization.fp8_utils import deep_gemm
 
@@ -1172,12 +1133,6 @@ def __init__(self, fd_config: FDConfig):
             num_embeddings=fd_config.model_config.vocab_size,
             prefix="lm_head",
         )
-        self.position_ids_buffer = paddle.empty(
-            [fd_config.scheduler_config.max_num_batched_tokens], dtype=paddle.int32
-        )
-        self.mask_encoder_batch_buffer = paddle.empty(
-            [fd_config.scheduler_config.max_num_batched_tokens, 1], dtype=paddle.int32
-        )
 
     @classmethod
     def name(cls):
@@ -1274,25 +1229,6 @@ def compute_logits(self, hidden_states: paddle.Tensor, forward_meta: ForwardMeta
         logits[:, self.ori_vocab_size :] = -float("inf")
         return logits
 
-    def pre_process(self, forward_meta):
-        """ """
-        seq_lens_encoder = forward_meta.seq_lens_encoder
-        seq_lens_decoder = forward_meta.seq_lens_decoder
-        seq_lens_this_time = forward_meta.seq_lens_this_time
-
-        current_total_tokens = forward_meta.ids_remove_padding.shape[0]
-        position_ids = self.position_ids_buffer[:current_total_tokens]
-        mask_encoder_batch = self.mask_encoder_batch_buffer[:current_total_tokens]
-
-        get_position_ids_and_mask_encoder_batch(
-            seq_lens_encoder,
-            seq_lens_decoder,
-            seq_lens_this_time,
-            position_ids,
-            mask_encoder_batch,
-        )
-        return position_ids, mask_encoder_batch
-
     def empty_input_forward(self, forward_meta):
         """
         empty_input_forward
@@ -1313,7 +1249,6 @@ def forward(
         forward_meta: ForwardMeta,
     ):
         ids_remove_padding = inputs["ids_remove_padding"]
-        forward_meta.position_ids, forward_meta.mask_encoder_batch = self.pre_process(forward_meta)
         hidden_states = self.model(
             ids_remove_padding=ids_remove_padding,
             forward_meta=forward_meta,
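For reference, the removed `compute_slot_mapping` helper encodes the paged-KV-cache addressing rule `slot = block_id * block_size + offset_in_block`, where `block_id` comes from `block_tables[batch_id, position // block_size]`; the indexer cache write now consumes the precomputed `forward_meta.slot_mapping` instead. Below is a minimal NumPy sketch of that rule with illustrative shapes and values, not FastDeploy API:

```python
import numpy as np

def compute_slot_mapping_ref(block_tables, positions, batch_id_per_token, block_size):
    """Reference slot computation: slot = block_id * block_size + offset_in_block."""
    block_idx = positions // block_size                      # logical block each token falls into
    block_ids = block_tables[batch_id_per_token, block_idx]  # physical block from the per-request table
    block_offset = positions % block_size                    # offset inside that block
    return (block_ids * block_size + block_offset).astype(np.int64)

# Two requests, block_size = 4; request 0 owns physical blocks [7, 2], request 1 owns [5, 9].
block_tables = np.array([[7, 2], [5, 9]])
positions = np.array([0, 1, 4, 0])           # token positions within their own request
batch_id_per_token = np.array([0, 0, 0, 1])  # first three tokens belong to request 0
print(compute_slot_mapping_ref(block_tables, positions, batch_id_per_token, 4))
# -> [28 29  8 20], e.g. position 4 of request 0 maps to block 2, offset 0 -> slot 2*4 + 0 = 8
```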