Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions fastdeploy/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -671,6 +671,7 @@ def __init__(
self.pod_ip: str = None
# enable the custom all-reduce kernel and fall back to NCCL(dist.all_reduce).
self.disable_custom_all_reduce: bool = False
self.enable_flashinfer_allreduce_fusion: bool = False
for key, value in args.items():
if hasattr(self, key):
setattr(self, key, value)
Expand Down
11 changes: 11 additions & 0 deletions fastdeploy/engine/args_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,11 @@ class EngineArgs:
Flag to disable the custom all-reduce kernel.
"""

enable_flashinfer_allreduce_fusion: bool = False
"""
Flag to enable all reduce fusion kernel in flashinfer.
"""

use_internode_ll_two_stage: bool = False
"""
Flag to use the internode_ll_two_stage kernel.
Expand Down Expand Up @@ -990,6 +995,12 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
default=EngineArgs.disable_custom_all_reduce,
help="Flag to disable custom all-reduce.",
)
parallel_group.add_argument(
"--enable-flashinfer-allreduce-fusion",
action="store_true",
default=EngineArgs.enable_flashinfer_allreduce_fusion,
help="Flag to enable all reduce fusion kernel in flashinfer.",
)
parallel_group.add_argument(
"--use-internode-ll-two-stage",
action="store_true",
Expand Down
1 change: 1 addition & 0 deletions fastdeploy/engine/common_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -2503,6 +2503,7 @@ def _start_worker_service(self):
"moe_gate_fp32": self.cfg.model_config.moe_gate_fp32,
"enable_entropy": self.cfg.model_config.enable_entropy,
"enable_overlap_schedule": self.cfg.scheduler_config.enable_overlap_schedule,
"enable_flashinfer_allreduce_fusion": self.cfg.parallel_config.enable_flashinfer_allreduce_fusion,
}
for worker_flag, value in worker_store_true_flag.items():
if value:
Expand Down
1 change: 1 addition & 0 deletions fastdeploy/engine/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -656,6 +656,7 @@ def _start_worker_service(self):
"enable_entropy": self.cfg.model_config.enable_entropy,
"ep_prefill_use_worst_num_tokens": self.cfg.parallel_config.ep_prefill_use_worst_num_tokens,
"enable_overlap_schedule": self.cfg.scheduler_config.enable_overlap_schedule,
"enable_flashinfer_allreduce_fusion": self.cfg.parallel_config.enable_flashinfer_allreduce_fusion,
}
for worker_flag, value in worker_store_true_flag.items():
if value:
Expand Down
209 changes: 209 additions & 0 deletions fastdeploy/model_executor/layers/flashinfer_comm_fusion.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,209 @@
"""
# Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""

from typing import Optional, Tuple

import paddle
import paddle.distributed as dist

from fastdeploy.config import FDConfig
from fastdeploy.model_executor.utils import has_flashinfer
from fastdeploy.utils import get_logger

logger = get_logger("flashinfer", "flashinfer.log")

_flashinfer_comm = None
_workspace_manager = None


def _get_flashinfer_comm():
    """Lazily import flashinfer.comm to avoid side effects at module load time."""
    global _flashinfer_comm
    # Import only once; subsequent calls return the cached module (or None).
    if _flashinfer_comm is None and has_flashinfer():
        try:
            with paddle.use_compat_guard(enable=True, scope={"flashinfer"}):
                import flashinfer.comm as comm

            _flashinfer_comm = comm
        except ImportError:
            logger.warning("flashinfer.comm is not available, falling back to standard " "implementation")
    return _flashinfer_comm


class FlashInferWorkspaceManager:
    """Owns the process-wide FlashInfer IPC workspace used by the fused
    allreduce + RMSNorm kernel.

    A single module-level instance is shared; `initialize` is idempotent for
    an unchanged world size, and `cleanup` tears the workspace down so it can
    be rebuilt (e.g. after a world-size change).
    """

    def __init__(self):
        # IPC workspace tensor handed to trtllm_allreduce_fusion as workspace_ptrs.
        self.workspace_tensor = None
        # Opaque IPC handles returned by the create call; needed for teardown.
        self.ipc_handles = None
        self.world_size = None
        self.rank = None
        self.initialized = False

    def initialize(
        self,
        world_size: int,
        rank: int,
        max_token_num: int,
        hidden_dim: int,
        group=None,
        use_fp32_lamport: bool = False,
    ):
        """Initialize workspace

        Creates (or re-creates) the IPC workspace sized for at most
        `max_token_num` tokens of width `hidden_dim`. No-op when already
        initialized for the same world size; silently skips when
        flashinfer.comm cannot be imported.
        """
        # Fast path: workspace already set up for this world size.
        # NOTE(review): rank / max_token_num / hidden_dim changes are not
        # re-checked here — callers are presumed to keep those stable. Verify.
        if self.initialized and self.world_size == world_size:
            return

        comm = _get_flashinfer_comm()
        if comm is None:
            logger.warning("FlashInfer comm not available, skipping workspace " "initialization")
            return

        # Tear down any stale workspace (e.g. from a previous world size)
        # before allocating a new one.
        self.cleanup()

        self.ipc_handles, self.workspace_tensor = comm.trtllm_create_ipc_workspace_for_all_reduce_fusion(
            rank,
            world_size,
            max_token_num,
            hidden_dim,
            group=group,
            use_fp32_lamport=use_fp32_lamport,
        )

        self.world_size = world_size
        self.rank = rank
        self.initialized = True

        logger.info(f"FlashInfer workspace initialized for rank {rank}, " f"world_size {world_size}")

    def cleanup(self):
        """Clean up workspace

        Destroys the IPC workspace if one exists. Errors are logged rather
        than raised, and local state is reset unconditionally.
        """
        if self.initialized and self.ipc_handles is not None:
            try:
                comm = _get_flashinfer_comm()
                if comm is not None:
                    # NOTE(review): destruction uses dist.get_group() (the
                    # default group), while creation may have used a caller
                    # supplied `group` — confirm these always match.
                    comm.trtllm_destroy_ipc_workspace_for_all_reduce(self.ipc_handles, group=dist.get_group())
            except Exception as e:
                logger.warning(f"Failed to cleanup FlashInfer workspace: {e}")
            finally:
                # Always drop references so a later initialize() starts clean.
                self.workspace_tensor = None
                self.ipc_handles = None
                self.initialized = False


_workspace_manager = FlashInferWorkspaceManager()


def ensure_workspace_initialized(
    fd_config: FDConfig, max_token_num: int = 2048, hidden_dim: int = 4096, use_fp32_lamport: bool = False
) -> bool:
    """Ensure the process-wide FlashInfer allreduce workspace is initialized.

    Args:
        fd_config: Global FastDeploy config; ``tensor_parallel_size`` gates
            whether the fusion applies at all.
        max_token_num: Maximum token count the workspace must accommodate.
        hidden_dim: Hidden dimension the workspace is sized for.
        use_fp32_lamport: Use the fp32 Lamport variant of the workspace.

    Returns:
        True when the shared workspace is ready for use; False when FlashInfer
        is unavailable or tensor parallelism is disabled (world_size <= 1).
    """
    comm = _get_flashinfer_comm()
    if not has_flashinfer() or comm is None:
        return False

    assert fd_config is not None
    world_size = fd_config.parallel_config.tensor_parallel_size
    if world_size <= 1:
        # Single rank: no allreduce, hence no workspace needed.
        return False

    rank = dist.get_rank()

    # (Re)initialize when never initialized or when the world size changed.
    if not _workspace_manager.initialized or _workspace_manager.world_size != world_size:
        _workspace_manager.initialize(
            world_size=world_size,
            rank=rank,
            max_token_num=max_token_num,
            hidden_dim=hidden_dim,
            use_fp32_lamport=use_fp32_lamport,
        )

    return _workspace_manager.initialized


def flashinfer_allreduce_residual_rmsnorm(
    fd_config: FDConfig,
    input_tensor: paddle.Tensor,
    residual: paddle.Tensor,
    weight: paddle.Tensor,
    eps: float = 1e-6,
    max_token_num: int = 2048,
    use_oneshot: Optional[bool] = None,
    trigger_completion_at_end: bool = False,
    fp32_acc: bool = False,
) -> Tuple[Optional[paddle.Tensor], Optional[paddle.Tensor]]:
    """
    Use FlashInfer's fused allreduce + residual + RMS norm operation.

    Args:
        fd_config: Global FastDeploy config (supplies tensor_parallel_size).
        input_tensor: Per-rank partial result of shape [token_num, hidden_dim]
            to be all-reduced across tensor-parallel ranks.
        residual: Residual tensor added after the allreduce.
        weight: RMSNorm gamma weights.
        eps: RMSNorm epsilon.
        max_token_num: Upper bound on token_num the IPC workspace is sized for.
        use_oneshot: Force the one-shot allreduce variant; None lets the
            kernel decide.
        trigger_completion_at_end: Passed through to the FlashInfer kernel.
        fp32_acc: Accumulate the reduction in fp32.

    Returns:
        (norm_out, residual_out) on success. (None, None) when FlashInfer is
        unavailable, running on a single GPU, or the workspace cannot be
        initialized — callers must fall back to the unfused path in that case.
    """
    comm = _get_flashinfer_comm()
    if not has_flashinfer() or comm is None:
        logger.debug("FlashInfer not available, falling back to standard " "implementation")
        return None, None

    assert fd_config is not None
    world_size = fd_config.parallel_config.tensor_parallel_size
    if world_size <= 1:
        logger.debug("Single GPU, no need for allreduce fusion")
        return None, None

    # The shared IPC workspace is sized for at most max_token_num tokens.
    assert input_tensor.shape[0] <= max_token_num

    if not ensure_workspace_initialized(
        fd_config=fd_config,
        max_token_num=max_token_num,
        hidden_dim=input_tensor.shape[-1],
        use_fp32_lamport=(input_tensor.dtype == paddle.float32),
    ):
        logger.debug("FlashInfer workspace not available")
        return None, None

    token_num, hidden_dim = input_tensor.shape

    residual_out = paddle.empty_like(residual)
    norm_out = paddle.empty_like(input_tensor)
    # support empty tensor
    if input_tensor.shape[0] == 0:
        return norm_out, residual_out
    comm.trtllm_allreduce_fusion(
        allreduce_in=input_tensor,
        world_size=world_size,
        world_rank=dist.get_rank(),
        token_num=token_num,
        hidden_dim=hidden_dim,
        workspace_ptrs=_workspace_manager.workspace_tensor,
        launch_with_pdl=True,
        use_oneshot=use_oneshot,
        trigger_completion_at_end=trigger_completion_at_end,
        fp32_acc=fp32_acc,
        pattern_code=(comm.AllReduceFusionPattern.kARResidualRMSNorm),
        allreduce_out=None,
        residual_in=residual,
        residual_out=residual_out,
        norm_out=norm_out,
        quant_out=None,
        scale_out=None,
        rms_gamma=weight,
        rms_eps=eps,
        scale_factor=None,
        layout_code=None,
    )

    return norm_out, residual_out


def cleanup_flashinfer_workspace():
    """Release the process-wide FlashInfer IPC workspace, if one exists.

    Safe to call multiple times; delegates to FlashInferWorkspaceManager.cleanup,
    which resets its state after teardown.
    """
    # Read-only access to the module-level singleton; no `global` needed.
    if _workspace_manager is not None:
        _workspace_manager.cleanup()
16 changes: 14 additions & 2 deletions fastdeploy/model_executor/layers/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -853,6 +853,7 @@ def __init__(
skip_quant: bool = False,
weight_dtype: str = "",
layer_id: int = -1,
enable_all_reduce_fusion: bool = None,
):
"""
Initialize a linear layer with additional parameters for inference and quantization.
Expand All @@ -864,9 +865,17 @@ def __init__(
input_size (int): Number of input features. Defaults to None.
output_size (int): Number of output features. Defaults to None.
with_bias (bool): Whether to include bias or not. Defaults to False.
skip_quant (bool): Whether to skip quantization. Defaults to False.
skip_quant (bool): Whether to skip quantization or not. Defaults to False.
enable_all_reduce_fusion (bool, optional): Whether to enable all-reduce fusion.
If None, it is determined by the config flag and prefix. Defaults to None.
"""
self.fd_config = fd_config
if enable_all_reduce_fusion is None:
self.enable_all_reduce_fusion = False
else:
self.enable_all_reduce_fusion = (
fd_config.parallel_config.enable_flashinfer_allreduce_fusion and enable_all_reduce_fusion
)
self.ep_size = fd_config.parallel_config.expert_parallel_size
self.tp_size = fd_config.parallel_config.tensor_parallel_size
self.tp_group = fd_config.parallel_config.tp_group
Expand Down Expand Up @@ -944,7 +953,10 @@ def forward_cuda(self, x: paddle.Tensor) -> paddle.Tensor:

out = self.quant_method.apply(self, x)

if self.reduce_results and self.tp_size > 1:
need_tp_all_reduce = (
self.reduce_results and self.tp_size > 1 and not (self.enable_all_reduce_fusion and out.shape[0] <= 2048)

This comment was marked as outdated.

)
if need_tp_all_reduce:
out = tensor_model_parallel_all_reduce(out, self.tp_group)

return out
Expand Down
11 changes: 11 additions & 0 deletions fastdeploy/model_executor/layers/normalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
is_batch_invariant_mode_enabled,
rms_norm_batch_invariant,
)
from .flashinfer_comm_fusion import flashinfer_allreduce_residual_rmsnorm
from .utils import get_tensor, modules_to_convert


Expand Down Expand Up @@ -122,6 +123,10 @@ def __init__(
self.tp_rank = self.fd_config.parallel_config.tensor_parallel_rank
self.tp_group = self.fd_config.parallel_config.tp_group
is_input_norm = prefix.endswith(".input_layernorm")
self.enable_all_reduce_fusion = (
fd_config.parallel_config.enable_flashinfer_allreduce_fusion and "post_attention_layernorm" in prefix
)

self.is_last_norm = prefix.endswith(".norm")
self.split_x = (
self.fd_config.parallel_config.use_sequence_parallel_moe
Expand Down Expand Up @@ -240,6 +245,12 @@ def forward(
norm_out = rms_norm(x, self.weight, self.eps)
return norm_out.astype(x_dtype), residual_out
norm_out = self.norm_func(x, residual_input, self.weight, self.eps)
# enable trtllm all reduce fusion
elif self.enable_all_reduce_fusion and x.shape[0] <= 2048:

This comment was marked as outdated.

norm_out = flashinfer_allreduce_residual_rmsnorm(
fd_config=self.fd_config, input_tensor=x, residual=residual_input, weight=self.weight, eps=self.eps
)
assert norm_out[0] is not None, "Trtllm-all-reduce fusion failed!"

This comment was marked as outdated.

This comment was marked as outdated.

else:
if is_batch_invariant_mode_enabled():
# M-invariant path: per-row Triton kernel, no cross-row reduction
Expand Down
9 changes: 2 additions & 7 deletions fastdeploy/model_executor/layers/quantization/mxfp4.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,6 @@
# limitations under the License.
"""

import importlib
import importlib.util
import math
from enum import Enum
from typing import Callable, Optional
Expand All @@ -25,11 +23,12 @@

from fastdeploy import envs
from fastdeploy.model_executor.layers.moe.fused_moe_backend_base import MoEMethodBase
from fastdeploy.model_executor.utils import set_weight_attrs
from fastdeploy.model_executor.utils import has_flashinfer, set_weight_attrs
from fastdeploy.platforms import current_platform

if current_platform.is_cuda():
from fastdeploy.model_executor.ops.gpu import moe_expert_dispatch

from fastdeploy.utils import get_logger

from ..moe import FusedMoE
Expand Down Expand Up @@ -59,10 +58,6 @@ def check_device_capability(num):
return False


def has_flashinfer():
return importlib.util.find_spec("flashinfer") is not None


def round_up(a, b):
return ((a + b - 1) // b) * b

Expand Down
2 changes: 1 addition & 1 deletion fastdeploy/model_executor/models/glm4_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,6 @@ def __init__(
self.tensor_parallel_size = fd_config.parallel_config.tensor_parallel_size
self.tensor_parallel_rank = fd_config.parallel_config.tensor_parallel_rank
self.tp_group = fd_config.parallel_config.tp_group

self.use_ep = self.expert_parallel_size > 1
self.use_tp = self.tensor_parallel_size > 1

Expand Down Expand Up @@ -229,6 +228,7 @@ def __init__(self, fd_config: FDConfig, layer_id: int, prefix: str = "") -> None
input_size=fd_config.model_config.num_attention_heads * fd_config.model_config.head_dim,
output_size=fd_config.model_config.hidden_size,
layer_id=layer_id,
enable_all_reduce_fusion=fd_config.parallel_config.enable_flashinfer_allreduce_fusion,
)

self.attn = Attention(
Expand Down
6 changes: 6 additions & 0 deletions fastdeploy/model_executor/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
# limitations under the License.
"""

import importlib
import importlib.util
import os
import re
from collections.abc import Mapping
Expand Down Expand Up @@ -553,6 +555,10 @@ def fn(loaded_weight_name, is_moe):
return fn


def has_flashinfer():
    """Return True when the optional ``flashinfer`` package is importable."""
    spec = importlib.util.find_spec("flashinfer")
    return spec is not None


@cache
def get_sm_version():
if paddle.cuda.is_available():
Expand Down
Loading
Loading