-
Notifications
You must be signed in to change notification settings - Fork 736
[Optimization] enable trtllm_all_reduce fusion kernel in glm model #6660
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.
Already on GitHub? Sign in to your account
base: develop
Are you sure you want to change the base?
Changes from all commits
9bc0020
2e23487
cb10bba
17449a9
cf3b69d
b2396e9
a855786
2306881
1905ad6
d98d180
f36bb5d
d924c2a
a691e49
f20ecd5
022c8a4
8fa3574
ab838a3
fb08b6c
6e13e3d
f77a1b6
b785ba1
243ebf5
dcda582
93fad2c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,209 @@ | ||
| """ | ||
| # Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
| """ | ||
|
|
||
| from typing import Optional, Tuple | ||
|
|
||
| import paddle | ||
| import paddle.distributed as dist | ||
|
|
||
| from fastdeploy.config import FDConfig | ||
| from fastdeploy.model_executor.utils import has_flashinfer | ||
| from fastdeploy.utils import get_logger | ||
|
|
||
# Module-scoped logger; writes FlashInfer-related messages to flashinfer.log.
logger = get_logger("flashinfer", "flashinfer.log")
|
|
||
# Cached flashinfer.comm module; populated lazily by _get_flashinfer_comm().
_flashinfer_comm = None
# Placeholder for the process-wide workspace manager singleton (rebound later
# in this module to a FlashInferWorkspaceManager instance).
_workspace_manager = None
|
|
||
|
|
||
def _get_flashinfer_comm():
    """Lazily import flashinfer.comm to avoid side effects at module load time."""
    global _flashinfer_comm
    # Already imported once — reuse the cached module.
    if _flashinfer_comm is not None:
        return _flashinfer_comm
    if not has_flashinfer():
        return None
    try:
        with paddle.use_compat_guard(enable=True, scope={"flashinfer"}):
            import flashinfer.comm as comm
        _flashinfer_comm = comm
    except ImportError:
        logger.warning("flashinfer.comm is not available, falling back to standard implementation")
    return _flashinfer_comm
|
|
||
|
|
||
class FlashInferWorkspaceManager:
    """Manages the shared IPC workspace used by FlashInfer's trtllm allreduce
    fusion kernels.

    A workspace is created once per (rank, world_size) and reused across calls;
    ``cleanup()`` releases it.
    """

    def __init__(self):
        # IPC workspace tensor, passed to trtllm_allreduce_fusion as workspace_ptrs.
        self.workspace_tensor = None
        # Opaque IPC handles returned at creation; required for teardown.
        self.ipc_handles = None
        self.world_size = None
        self.rank = None
        # Communication group the workspace was created on; remembered so that
        # teardown targets the same group as creation.
        self.group = None
        self.initialized = False

    def initialize(
        self,
        world_size: int,
        rank: int,
        max_token_num: int,
        hidden_dim: int,
        group=None,
        use_fp32_lamport: bool = False,
    ):
        """Create the IPC workspace for allreduce fusion.

        No-op when already initialized for the same world_size; any previously
        held workspace is destroyed first.
        """
        if self.initialized and self.world_size == world_size:
            return

        comm = _get_flashinfer_comm()
        if comm is None:
            logger.warning("FlashInfer comm not available, skipping workspace initialization")
            return

        self.cleanup()

        self.ipc_handles, self.workspace_tensor = comm.trtllm_create_ipc_workspace_for_all_reduce_fusion(
            rank,
            world_size,
            max_token_num,
            hidden_dim,
            group=group,
            use_fp32_lamport=use_fp32_lamport,
        )

        self.world_size = world_size
        self.rank = rank
        # Fix: remember the creation-time group. Previously cleanup() always
        # destroyed on dist.get_group(), which is wrong when a non-default
        # group was used to create the workspace.
        self.group = group
        self.initialized = True

        logger.info(f"FlashInfer workspace initialized for rank {rank}, world_size {world_size}")

    def cleanup(self):
        """Destroy the IPC workspace, logging (not raising) teardown errors."""
        if self.initialized and self.ipc_handles is not None:
            try:
                comm = _get_flashinfer_comm()
                if comm is not None:
                    # Destroy on the same group the workspace was created with;
                    # fall back to the default group when none was given.
                    group = self.group if self.group is not None else dist.get_group()
                    comm.trtllm_destroy_ipc_workspace_for_all_reduce(self.ipc_handles, group=group)
            except Exception as e:
                logger.warning(f"Failed to cleanup FlashInfer workspace: {e}")
            finally:
                self.workspace_tensor = None
                self.ipc_handles = None
                self.group = None
                self.initialized = False
|
|
||
|
|
||
# Process-wide singleton that owns the FlashInfer IPC workspace.
_workspace_manager = FlashInferWorkspaceManager()
|
|
||
|
|
||
def ensure_workspace_initialized(
    fd_config: FDConfig, max_token_num: int = 2048, hidden_dim: int = 4096, use_fp32_lamport: bool = False
) -> bool:
    """Ensure the shared FlashInfer IPC workspace is initialized for this rank.

    Args:
        fd_config: FastDeploy config; ``tensor_parallel_size`` decides whether
            the fused allreduce path applies at all.
        max_token_num: Upper bound on tokens per fused allreduce call.
        hidden_dim: Hidden dimension of the tensors fed to the fusion kernel.
        use_fp32_lamport: Allocate the workspace for float32 inputs.

    Returns:
        True when the workspace is ready; False when FlashInfer is unavailable
        or tensor parallelism is not in use (world_size <= 1).

    Raises:
        ValueError: If ``fd_config`` is None.
    """
    comm = _get_flashinfer_comm()
    if not has_flashinfer() or comm is None:
        return False

    # Explicit raise instead of `assert`: asserts are stripped under `python -O`.
    if fd_config is None:
        raise ValueError("fd_config must not be None")
    world_size = fd_config.parallel_config.tensor_parallel_size
    if world_size <= 1:
        return False

    # initialize() is already a no-op when set up for this world_size, so no
    # external initialized/world_size pre-check is needed here.
    _workspace_manager.initialize(
        world_size=world_size,
        rank=dist.get_rank(),
        max_token_num=max_token_num,
        hidden_dim=hidden_dim,
        use_fp32_lamport=use_fp32_lamport,
    )

    return _workspace_manager.initialized
|
|
||
|
|
||
def flashinfer_allreduce_residual_rmsnorm(
    fd_config: FDConfig,
    input_tensor: paddle.Tensor,
    residual: paddle.Tensor,
    weight: paddle.Tensor,
    eps: float = 1e-6,
    max_token_num: int = 2048,
    use_oneshot: Optional[bool] = None,
    trigger_completion_at_end: bool = False,
    fp32_acc: bool = False,
) -> Tuple[paddle.Tensor, paddle.Tensor]:
    """Use FlashInfer's fused allreduce + residual + RMS norm operation.

    Args:
        fd_config: FastDeploy config; requires ``tensor_parallel_size > 1``.
        input_tensor: 2-D tensor of shape (token_num, hidden_dim) to allreduce.
        residual: Residual tensor added in the fused op (same shape assumed —
            TODO confirm against callers).
        weight: RMSNorm gamma.
        eps: RMSNorm epsilon.
        max_token_num: Maximum tokens the IPC workspace supports.
        use_oneshot: Force one-shot vs. two-shot allreduce; None lets the
            kernel decide.
        trigger_completion_at_end: Forwarded to the fusion kernel.
        fp32_acc: Accumulate the allreduce in float32.

    Returns:
        (norm_out, residual_out) on success, or (None, None) when the fused
        path is unavailable and the caller must fall back to the standard path.

    Raises:
        ValueError: If token_num exceeds ``max_token_num``.
    """
    comm = _get_flashinfer_comm()
    if not has_flashinfer() or comm is None:
        logger.debug("FlashInfer not available, falling back to standard implementation")
        return None, None

    # Explicit raise instead of `assert`: asserts are stripped under `python -O`.
    if fd_config is None:
        raise ValueError("fd_config must not be None")
    world_size = fd_config.parallel_config.tensor_parallel_size
    if world_size <= 1:
        logger.debug("Single GPU, no need for allreduce fusion")
        return None, None

    token_num, hidden_dim = input_tensor.shape
    if token_num > max_token_num:
        raise ValueError(f"token_num ({token_num}) exceeds max_token_num ({max_token_num})")

    # Empty batch: nothing to reduce — return empty outputs without touching
    # (or lazily creating) the IPC workspace.
    if token_num == 0:
        return paddle.empty_like(input_tensor), paddle.empty_like(residual)

    if not ensure_workspace_initialized(
        fd_config=fd_config,
        max_token_num=max_token_num,
        hidden_dim=hidden_dim,
        use_fp32_lamport=(input_tensor.dtype == paddle.float32),
    ):
        logger.debug("FlashInfer workspace not available")
        return None, None

    residual_out = paddle.empty_like(residual)
    norm_out = paddle.empty_like(input_tensor)

    comm.trtllm_allreduce_fusion(
        allreduce_in=input_tensor,
        world_size=world_size,
        world_rank=dist.get_rank(),
        token_num=token_num,
        hidden_dim=hidden_dim,
        workspace_ptrs=_workspace_manager.workspace_tensor,
        launch_with_pdl=True,
        use_oneshot=use_oneshot,
        trigger_completion_at_end=trigger_completion_at_end,
        fp32_acc=fp32_acc,
        pattern_code=comm.AllReduceFusionPattern.kARResidualRMSNorm,
        allreduce_out=None,
        residual_in=residual,
        residual_out=residual_out,
        norm_out=norm_out,
        quant_out=None,
        scale_out=None,
        rms_gamma=weight,
        rms_eps=eps,
        scale_factor=None,
        layout_code=None,
    )

    return norm_out, residual_out
|
|
||
|
|
||
def cleanup_flashinfer_workspace():
    """Release the process-wide FlashInfer IPC workspace, if one exists."""
    # No `global` needed: the singleton is only read here, never rebound.
    manager = _workspace_manager
    if manager is None:
        return
    manager.cleanup()
This comment was marked as outdated.
Sorry, something went wrong.
Uh oh!
There was an error while loading. Please reload this page.