Skip to content

Commit 08d2f16

Browse files
committed
add max token num setting for trtllm_allreducefusion
1 parent b314228 commit 08d2f16

File tree

1 file changed

+6
-1
lines changed

1 file changed

+6
-1
lines changed

fastdeploy/model_executor/layers/normalization.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -244,7 +244,12 @@ def forward(
244244
# enable trtllm all reduce fusion
245245
elif self.enable_all_reduce_fusion:
246246
norm_out = flashinfer_allreduce_residual_rmsnorm(
247-
fd_config=self.fd_config, input_tensor=x, residual=residual_input, weight=self.weight, eps=self.eps
247+
fd_config=self.fd_config,
248+
input_tensor=x,
249+
residual=residual_input,
250+
weight=self.weight,
251+
eps=self.eps,
252+
max_token_num=self.max_token_num,
248253
)
249254
else:
250255
norm_out = self.norm_func(

0 commit comments

Comments
 (0)