
[Optimization] enable trtllm_all_reduce fusion kernel in glm model#6660

Open
BingooYang wants to merge 24 commits into PaddlePaddle:develop from BingooYang:trtllm_allreduce

Conversation

@BingooYang
Contributor

@BingooYang BingooYang commented Mar 4, 2026

Motivation

Integrate the trtllm_allreduce_fusion operator into FD (FastDeploy).

Modifications

  1. Add a flashinfer allreduce fusion operator integration to FD.
  2. Change the GLM-4.5-Air model network structure to hook in the trtllm_allreduce_fusion operator (disabled by default).
  3. Add the command-line flag --enable-flashinfer-allreduce-fusion, which enables trtllm_allreduce_fusion.
  4. Add a unit test for the trtllm_allreduce_fusion operator.
  5. Move the def has_flashinfer() function into utils.py.
  6. Upgrade flashinfer to 0.4.1.2 (Python interface fix, C++20 compatibility fix).
  7. Add deletion of the flashinfer cache to the tests (a stale cache causes problems on CI machines).
  8. Change import flashinfer to a lazy import, fixing an issue where a global import combined with paddle compat caused model loading to hit the torch interface.
  9. Add the --enable-flashinfer-allreduce-fusion flag to some tests.
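The lazy-import change in item 8 can be sketched generically as follows. This is a minimal illustration of the pattern under stated assumptions, not the PR's actual code; the helper name and module-level cache are hypothetical:

```python
import importlib

# Hypothetical module-level cache; the PR's real implementation differs.
_module_cache = {}


def lazy_import(name):
    """Import a module on first use instead of at module load time.

    Deferring the import keeps a heavyweight dependency (here, flashinfer)
    from being pulled in while paddle's torch-compat shims are active,
    which is the failure mode item 8 describes.
    """
    if name not in _module_cache:
        _module_cache[name] = importlib.import_module(name)
    return _module_cache[name]


# At a call site, a top-level `import flashinfer` would become:
#   flashinfer = lazy_import("flashinfer")
```

Repeated calls return the same cached module object, so the import cost is paid at most once.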

Usage or Command

Local tests pass on both H-series and B-series GPUs.
python -m fastdeploy.entrypoints.openai.api_server --model /root/paddlejob/workspace/bingoo/model/GLM-4.5-Air --tensor-parallel-size 4 --port 8185 --max-num-batched-tokens 2048 --enable-flashinfer-allreduce-fusion

Accuracy Tests

python -m paddle.distributed.launch --gpus=0,1 ./FastDeploy/tests/layers/test_rms_allreduce_fusion.py

Checklist

  • Add at least one tag to the PR title.
    • Tag list: [[FDConfig],[APIServer],[Engine], [Scheduler], [PD Disaggregation], [Executor], [Graph Optimization], [Speculative Decoding], [RL], [Models], [Quantization], [Loader], [OP], [KVCache], [DataProcessor], [BugFix], [Docs], [CI], [Optimization], [Feature], [Benchmark], [Others], [XPU], [HPU], [GCU], [DCU], [Iluvatar], [Metax]]
    • You can add new tags based on the PR content, but the semantics must be clear.
  • Format your code; run pre-commit before committing.
  • Add unit tests, or explain in this PR why none are provided.
  • Provide accuracy results.
  • If the current PR is submitting to the release branch, make sure the PR has been submitted to the develop branch, then cherry-pick it to the release branch with the [Cherry-Pick] PR tag.

@paddle-bot

paddle-bot bot commented Mar 4, 2026

Thanks for your contribution!

@BingooYang BingooYang changed the title enable trtllm_all_reduce fusion kernel in glm model [Optimization] enable trtllm_all_reduce fusion kernel in glm model Mar 5, 2026
@codecov-commenter

codecov-commenter commented Mar 5, 2026

Codecov Report

❌ Patch coverage is 87.96296% with 13 lines in your changes missing coverage. Please review.
⚠️ Please upload report for BASE (develop@26c47c2). Learn more about missing BASE report.

Files with missing lines Patch % Lines
...oy/model_executor/layers/flashinfer_comm_fusion.py 90.90% 4 Missing and 4 partials ⚠️
fastdeploy/model_executor/layers/normalization.py 40.00% 2 Missing and 1 partial ⚠️
fastdeploy/model_executor/layers/linear.py 60.00% 1 Missing and 1 partial ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             develop    #6660   +/-   ##
==========================================
  Coverage           ?   73.54%           
==========================================
  Files              ?      384           
  Lines              ?    53711           
  Branches           ?     8426           
==========================================
  Hits               ?    39500           
  Misses             ?    11525           
  Partials           ?     2686           
Flag Coverage Δ
GPU 73.54% <87.96%> (?)

Flags with carried forward coverage won't be shown.

☔ View full report in Codecov by Sentry.

@BingooYang
Contributor Author

/re-run all-failed

6 similar comments

@fastdeploy-bot left a comment

🤖 AI Code Review | 2026-04-14 00:56 CST

📋 Review Summary

PR overview: enables the trtllm_allreduce fusion kernel in the GLM model, using the flashinfer library to fuse allreduce + residual + RMSNorm into a single operator.

Scope of changes: model_executor/layers/, engine/, config/, tests/
Impact tags: [Optimization] [OP]

📝 PR Convention Check

The PR title and description follow the conventions.

Issues

Severity File Summary
🔴 Bug normalization.py:253 An assert on fusion failure crashes the program
🟡 Suggestion flashinfer_comm_fusion.py:142 Magic number 2048 is hardcoded
🟡 Suggestion flashinfer_comm_fusion.py:206 The cleanup function is never called

Overall Assessment

The PR integrates the flashinfer allreduce fusion operator with a clear code structure and solid test coverage. However, there is one blocking bug: when fusion fails, an assert crashes the program instead of gracefully falling back to the standard implementation. Recommended to merge after fixing.

norm_out = flashinfer_allreduce_residual_rmsnorm(
    fd_config=self.fd_config, input_tensor=x, residual=residual_input, weight=self.weight, eps=self.eps
)
assert norm_out[0] is not None, "Trtllm-all-reduce fusion failed!"

🔴 Bug: when flashinfer_allreduce_residual_rmsnorm returns (None, None) (e.g. flashinfer is unavailable or the workspace is not initialized), the assert raises AssertionError and crashes the program.

flashinfer_allreduce_residual_rmsnorm returns (None, None) in the following cases:

  1. flashinfer is unavailable (lines 151-153)
  2. single-GPU scenario (lines 157-159)
  3. workspace not initialized (lines 169-170)

This is expected fallback behavior, not an error. The code should degrade gracefully to the standard implementation.

Suggested fix:

# enable trtllm all reduce fusion
elif self.enable_all_reduce_fusion and x.shape[0] <= 2048:
    norm_out = flashinfer_allreduce_residual_rmsnorm(
        fd_config=self.fd_config, input_tensor=x, residual=residual_input, weight=self.weight, eps=self.eps
    )
    if norm_out[0] is not None:
        # fusion succeeded; skip the subsequent norm computation
        return norm_out[0].astype(x_dtype), norm_out[1].astype(residual_input_dtype)
    # fusion failed; fall through to the standard implementation
    # (assert removed; execution continues below)

    residual: paddle.Tensor,
    weight: paddle.Tensor,
    eps: float = 1e-6,
    max_token_num: int = 2048,

🟡 Suggestion: the max_token_num default of 2048 is a hardcoded magic number repeated in several places (linear.py:957, normalization.py:249, flashinfer_comm_fusion.py:110,142).

Suggestions:

  1. Add a config option flashinfer_fusion_max_token_num to parallel_config
  2. Or define a constant FLASHINFER_FUSION_MAX_TOKEN_NUM = 2048

Either approach lets the fusion's maximum-token limit be managed and tuned in one place.
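The second option could look like the following sketch; the constant name is taken from the review comment, while the guard helper is hypothetical (it mirrors the `x.shape[0] <= 2048` check quoted above):

```python
# Single definition replacing the literal 2048 repeated across
# linear.py, normalization.py, and flashinfer_comm_fusion.py.
FLASHINFER_FUSION_MAX_TOKEN_NUM = 2048


def fusion_eligible(num_tokens: int) -> bool:
    """Return True when the batch is small enough for the fused kernel.

    Centralizing the cap means tuning it requires one edit, not four.
    """
    return num_tokens <= FLASHINFER_FUSION_MAX_TOKEN_NUM
```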

return norm_out, residual_out


def cleanup_flashinfer_workspace():

🟡 Suggestion: cleanup_flashinfer_workspace is defined but never called, which may leak workspace resources.

Suggested places to call cleanup:

  1. Before a worker process exits
  2. When the model is unloaded
  3. When fusion is explicitly disabled

For example, the cleanup logic could be added to the exit handling in worker_process.py.
