[Optimization] enable trtllm_all_reduce fusion kernel in glm model #6660
BingooYang wants to merge 24 commits into PaddlePaddle:develop
Conversation
Thanks for your contribution!
Codecov Report
❌ Patch coverage is
Additional details and impacted files:

```
@@           Coverage Diff            @@
##           develop   #6660   +/-   ##
==========================================
  Coverage         ?   73.54%
==========================================
  Files            ?      384
  Lines            ?    53711
  Branches         ?     8426
==========================================
  Hits             ?    39500
  Misses           ?    11525
  Partials         ?     2686
```
/re-run all-failed
e0fd641 to b314228
/re-run all-failed
08d2f16 to 09cb26d
b7c4a47 to be78caa
251517b to 6e13e3d
a43a5a6 to f77a1b6
/re-run all-failed
fastdeploy-bot left a comment
🤖 AI Code Review | 2026-04-14 00:56 CST
📋 Review Summary
PR overview: enables the trtllm_allreduce fusion kernel in the GLM model, using the flashinfer library to provide a fused allreduce + residual + RMSNorm operator.
Scope of changes: model_executor/layers/, engine/, config/, tests/
Impact tags: [Optimization] [OP]
📝 PR Convention Check
The PR title and description follow the conventions.
Issues
| Level | File | Summary |
|---|---|---|
| 🔴 Bug | normalization.py:253 | assert on fusion failure crashes the process |
| 🟡 Suggestion | flashinfer_comm_fusion.py:142 | hard-coded magic number 2048 |
| 🟡 Suggestion | flashinfer_comm_fusion.py:206 | cleanup function is never called |
Overall Assessment
The PR wires up the flashinfer allreduce fusion operator; the code structure is clear and test coverage is adequate. However, there is one blocking bug: when fusion fails, an assert crashes the process instead of falling back gracefully to the standard implementation. Recommend merging after this is fixed.
```python
norm_out = flashinfer_allreduce_residual_rmsnorm(
    fd_config=self.fd_config, input_tensor=x, residual=residual_input, weight=self.weight, eps=self.eps
)
assert norm_out[0] is not None, "Trtllm-all-reduce fusion failed!"
```
🔴 Bug: when flashinfer_allreduce_residual_rmsnorm returns (None, None) (e.g. flashinfer unavailable, workspace not initialized), the assert raises an AssertionError and crashes the process.
flashinfer_allreduce_residual_rmsnorm returns (None, None) in the following cases:
- flashinfer is unavailable (lines 151-153)
- single-GPU scenario (lines 157-159)
- workspace not initialized (lines 169-170)
These are expected fallback paths, not errors, so the code should degrade gracefully to the standard implementation.
Suggested fix:
```python
# enable trtllm all reduce fusion
elif self.enable_all_reduce_fusion and x.shape[0] <= 2048:
    norm_out = flashinfer_allreduce_residual_rmsnorm(
        fd_config=self.fd_config, input_tensor=x, residual=residual_input, weight=self.weight, eps=self.eps
    )
    if norm_out[0] is not None:
        # fusion succeeded: skip the remaining norm computation
        return norm_out[0].astype(x_dtype), norm_out[1].astype(residual_input_dtype)
    # fusion failed: fall through to the standard implementation
    # (assert removed; execution continues below)
```

```python
    residual: paddle.Tensor,
    weight: paddle.Tensor,
    eps: float = 1e-6,
    max_token_num: int = 2048,
```
🟡 Suggestion: the max_token_num default of 2048 is a hard-coded magic number repeated in several places (linear.py:957, normalization.py:249, flashinfer_comm_fusion.py:110,142).
Suggestions:
- add a flashinfer_fusion_max_token_num option to parallel_config, or
- define a constant FLASHINFER_FUSION_MAX_TOKEN_NUM = 2048
This allows the maximum token count for fusion to be managed and adjusted in one place.
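A minimal sketch of the constant-based variant of this suggestion; the constant name and the `should_use_fusion` helper are illustrative, not actual FastDeploy symbols:

```python
# Hypothetical: hoist the repeated 2048 limit into one module-level
# constant and use it as the shared default everywhere.
FLASHINFER_FUSION_MAX_TOKEN_NUM = 2048


def should_use_fusion(num_tokens: int,
                      max_token_num: int = FLASHINFER_FUSION_MAX_TOKEN_NUM) -> bool:
    """Gate for the fused kernel path, mirroring the `x.shape[0] <= 2048` check."""
    return num_tokens <= max_token_num
```

Callers in linear.py, normalization.py, and flashinfer_comm_fusion.py would then import the one constant instead of repeating the literal.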
```python
    return norm_out, residual_out

# ...

def cleanup_flashinfer_workspace():
```
🟡 Suggestion: cleanup_flashinfer_workspace is defined but never called, which may leak workspace resources.
Suggest calling cleanup in the following scenarios:
- before the worker process exits
- when the model is unloaded
- when fusion is explicitly disabled
For example, the cleanup logic could be added to the exit handling in worker_process.py.
Motivation
Integrate the trtllm_allreduce_fusion operator into FastDeploy.
Modifications
Usage or Command
Local tests pass on both H-series and B-series GPUs.
python -m fastdeploy.entrypoints.openai.api_server --model /root/paddlejob/workspace/bingoo/model/GLM-4.5-Air --tensor-parallel-size 4 --port 8185 --max-num-batched-tokens 2048 --enable-flashinfer-allreduce-fusion
Accuracy Tests
python -m paddle.distributed.launch --gpus=0,1 ./FastDeploy/tests/layers/test_rms_allreduce_fusion.py
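For a quick single-process sanity check of the fused semantics (allreduce + residual add + RMSNorm, as described in the review summary), a NumPy reference that emulates the allreduce by summing per-rank shards might look like this; the function and variable names are mine, not FastDeploy's:

```python
import numpy as np


def ref_allreduce_residual_rmsnorm(shards, residual, weight, eps=1e-6):
    """NumPy reference: sum-'allreduce' over rank shards, add the residual,
    then apply RMSNorm over the last axis. Returns (norm_out, residual_out)."""
    x = np.sum(shards, axis=0)           # emulated allreduce (sum across ranks)
    hidden = x + residual                # residual add
    rms = np.sqrt(np.mean(hidden ** 2, axis=-1, keepdims=True) + eps)
    return hidden / rms * weight, hidden
```

The distributed test above could compare the fused kernel's output against such a reference within a small tolerance.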
Checklist
- Add at least one tag in the PR title from: [FDConfig], [APIServer], [Engine], [Scheduler], [PD Disaggregation], [Executor], [Graph Optimization], [Speculative Decoding], [RL], [Models], [Quantization], [Loader], [OP], [KVCache], [DataProcessor], [BugFix], [Docs], [CI], [Optimization], [Feature], [Benchmark], [Others], [XPU], [HPU], [GCU], [DCU], [Iluvatar], [Metax]
- Run pre-commit before commit.
- For changes targeting a release branch, make sure the PR has been submitted to the develop branch first, then cherry-pick it to the release branch with the [Cherry-Pick] PR tag.