Commit b7c4a47

add max token num branch for trtllm_allreduce_fusion
1 parent: 09cb26d

File tree: 3 files changed, +5 −3 lines

fastdeploy/model_executor/layers/linear.py

Lines changed: 4 additions & 1 deletion
@@ -940,7 +940,10 @@ def forward_cuda(self, x: paddle.Tensor) -> paddle.Tensor:
 
         out = self.quant_method.apply(self, x)
 
-        if self.reduce_results and self.tp_size > 1 and not self.enable_all_reduce_fusion:
+        need_tp_all_reduce = (
+            self.reduce_results and self.tp_size > 1 and not (self.enable_all_reduce_fusion and out.shape[0] <= 2048)
+        )
+        if need_tp_all_reduce:
             out = tensor_model_parallel_all_reduce(out, self.tp_group)
 
         return out
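
The new out.shape[0] <= 2048 check is the "max token num branch" from the commit message: the trtllm all-reduce fusion only covers batches up to a fixed maximum token count, so larger batches must fall back to the plain tensor-parallel all-reduce here. A minimal sketch of the gate, with hypothetical constant and helper names (the commit itself inlines the literal 2048 at both call sites):

# Sketch only, not from the commit; the names below are hypothetical.

FLASHINFER_FUSION_MAX_TOKEN_NUM = 2048  # max token num the fused path supports


def fusion_handles_reduction(enable_fusion: bool, num_tokens: int) -> bool:
    """True when the downstream fused allreduce+rmsnorm will do the reduction."""
    return enable_fusion and num_tokens <= FLASHINFER_FUSION_MAX_TOKEN_NUM


def need_tp_all_reduce(reduce_results: bool, tp_size: int,
                       enable_fusion: bool, num_tokens: int) -> bool:
    """linear.py's gate: reduce here unless the fused path will do it later."""
    return reduce_results and tp_size > 1 and not fusion_handles_reduction(
        enable_fusion, num_tokens
    )


# With fusion enabled, a 4096-token batch exceeds the cap, so the plain
# tensor-parallel all-reduce must still run; a 1024-token batch is fused.
assert need_tp_all_reduce(True, 8, enable_fusion=True, num_tokens=4096)
assert not need_tp_all_reduce(True, 8, enable_fusion=True, num_tokens=1024)

Hoisting the threshold into a shared constant like this would keep the two call sites from drifting apart if the supported token count ever changes; the commit keeps it inline.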

fastdeploy/model_executor/layers/normalization.py

Lines changed: 1 addition & 1 deletion
@@ -242,7 +242,7 @@ def forward(
                 return norm_out.astype(x_dtype), residual_out
             norm_out = self.norm_func(x, residual_input, self.weight, self.eps)
         # enable trtllm all reduce fusion
-        elif self.enable_all_reduce_fusion:
+        elif self.enable_all_reduce_fusion and x.shape[0] <= 2048:
             norm_out = flashinfer_allreduce_residual_rmsnorm(
                 fd_config=self.fd_config, input_tensor=x, residual=residual_input, weight=self.weight, eps=self.eps
             )
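
This is the same 2048-token gate as in linear.py, applied on the normalization side. The two conditions are logical complements, so for any batch exactly one site performs the tensor-parallel reduction: the fused allreduce+rmsnorm kernel when fusion is enabled and the batch fits, the plain all-reduce in linear.py otherwise. A quick property check of that invariant (hypothetical test code, assuming reduce_results is set and tp_size > 1):

# Hypothetical check, not part of the commit.

MAX_TOKEN_NUM = 2048  # threshold inlined by this commit


def linear_reduces(enable_fusion: bool, num_tokens: int) -> bool:
    # linear.py: not (self.enable_all_reduce_fusion and out.shape[0] <= 2048)
    return not (enable_fusion and num_tokens <= MAX_TOKEN_NUM)


def norm_fuses(enable_fusion: bool, num_tokens: int) -> bool:
    # normalization.py: elif self.enable_all_reduce_fusion and x.shape[0] <= 2048
    return enable_fusion and num_tokens <= MAX_TOKEN_NUM


for enable_fusion in (False, True):
    for num_tokens in (1, 2047, 2048, 2049, 16384):
        # Exactly one of the two sites reduces: never both, never neither.
        assert linear_reduces(enable_fusion, num_tokens) != norm_fuses(enable_fusion, num_tokens)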

fastdeploy/model_executor/models/glm4_moe.py

Lines changed: 0 additions & 1 deletion
@@ -127,7 +127,6 @@ def __init__(
         self.tensor_parallel_size = fd_config.parallel_config.tensor_parallel_size
         self.tensor_parallel_rank = fd_config.parallel_config.tensor_parallel_rank
         self.tp_group = fd_config.parallel_config.tp_group
-        self.enable_all_reduce_fusion = fd_config.parallel_config.enable_flashinfer_allreduce_fusion
         self.use_ep = self.expert_parallel_size > 1
         self.use_tp = self.tensor_parallel_size > 1

The attribute is dropped here, presumably because nothing in this file reads it any longer: the fusion gate now lives at the linear.py and normalization.py call sites above.
