[PyTorch] Support scaled + clamped SwiGLU in te.ops and enable fused MXFP8 grouped MLP #2855

ksivaman wants to merge 5 commits into NVIDIA:main

Conversation
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Greptile Summary

This PR introduces scaled + clamped SwiGLU support in te.ops and enables the fused MXFP8 grouped MLP path.

Confidence Score: 4/5 — safe to merge after adding the missing `limit` guard. One P1 issue remains: the fusion eligibility check in transformer_engine/pytorch/ops/_common.py guards `alpha` but omits the matching guard on `limit` (see the file table below).
| Filename | Overview |
|---|---|
| transformer_engine/pytorch/ops/_common.py | Extended fuse_grouped_mlp_ops to recognize ScaledClampedQGeGLU; guards alpha against the cuDNN-required 1.702 but omits a matching guard on limit (also cuDNN-hardcoded at 7.0), which can cause silent numerics errors. |
| transformer_engine/pytorch/ops/basic/swiglu.py | Adds ScaledClampedQGeGLU as a clean _ScaledGLU subclass that delegates kernel calls to a private ClampedSwiGLU instance; refactored helpers _tex_clamped_swiglu_forward / _tex_clamped_dswiglu to avoid code duplication. |
| transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py | Widens FC1/activation/FC2 type to accept ScaledClampedQGeGLU; correctly stores _cudnn_act_func="geglu" at construction time and passes it to the cuDNN kernel call. |
| transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py | Parallel backward changes: accepts ScaledClampedQGeGLU, sets _cudnn_dact_func="dgeglu", and uses it in the kernel call — mirrors the forward path consistently. |
| tests/pytorch/test_fusible_ops.py | Adds unit tests for ScaledClampedQGeGLU (standalone + interleaved) and parametrizes test_grouped_mlp over both activations; nvfp4+geglu+bias path is skipped with an unresolved numerics TODO. |
| docs/api/pytorch.rst | Adds autoapiclass entry for ScaledClampedQGeGLU in the correct alphabetical location. |
| transformer_engine/pytorch/ops/basic/__init__.py | Exports ScaledClampedQGeGLU alongside existing swiglu variants. |
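The P1 issue above concerns the eligibility check in `fuse_grouped_mlp_ops`. A minimal standalone sketch of the guard the review is asking for — the function name and structure here are illustrative, not TE's actual code; the constants come from the review's claim that the cuDNN kernel hard-codes `alpha=1.702` and `limit=7.0`:

```python
import math

# Values the review says are hard-coded in the cuDNN "geglu" kernel.
CUDNN_ALPHA = 1.702
CUDNN_LIMIT = 7.0

def fusion_eligible(alpha: float, limit: float, glu_interleave_size: int) -> bool:
    """Hypothetical eligibility check: fuse only when the op's parameters
    match what the cuDNN kernel hard-codes; otherwise fall back to the
    unfused path rather than silently computing wrong numerics."""
    if not math.isclose(alpha, CUDNN_ALPHA, rel_tol=1e-6):
        return False
    # The guard the review flags as missing: `limit` must also match.
    if not math.isclose(limit, CUDNN_LIMIT, rel_tol=1e-6):
        return False
    return glu_interleave_size == 32

print(fusion_eligible(1.702, 7.0, 32))  # True
print(fusion_eligible(1.702, 6.0, 32))  # False: rejected by the limit guard
```

Without the second check, an op constructed with a non-default `limit` would still be fused and the kernel would quietly apply 7.0 instead.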
Flowchart
%%{init: {'theme': 'neutral'}}%%
flowchart TD
A[fuse_grouped_mlp_ops] --> B{window = GroupedLinear + ? + GroupedLinear}
B -- yes --> C{isinstance ScaledSwiGLU or ScaledClampedQGeGLU}
C -- ScaledClampedQGeGLU --> D{alpha ≈ 1.702?}
D -- no --> E[skip fusion]
D -- yes --> F["⚠️ limit check MISSING\n(cuDNN expects limit=7.0)"]
F --> G{glu_interleave_size == 32?}
G -- no --> E
G -- yes --> H[Create fused op]
H --> I{ForwardGroupedMLP\n_cudnn_act_func}
H --> J{BackwardGroupedMLP\n_cudnn_dact_func}
I -- ScaledClampedQGeGLU --> K["act_func='geglu'"]
I -- ScaledSwiGLU --> L["act_func='swiglu'"]
J -- ScaledClampedQGeGLU --> M["act_func='dgeglu'"]
J -- ScaledSwiGLU --> N["act_func='dswiglu'"]
C -- ScaledSwiGLU --> G
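The dispatch at the bottom of the flowchart can be sketched as a simple mapping. The class and kernel-name strings come from the file table above; the wiring itself is illustrative, not TE's actual code:

```python
# Illustrative dispatch from middle-op class name to the cuDNN kernel names
# stored as _cudnn_act_func / _cudnn_dact_func in the fused forward/backward ops.
ACT_FUNC = {
    "ScaledSwiGLU": ("swiglu", "dswiglu"),
    "ScaledClampedQGeGLU": ("geglu", "dgeglu"),
}

def cudnn_kernel_names(op_class_name: str) -> tuple[str, str]:
    """Return (_cudnn_act_func, _cudnn_dact_func) for a supported middle op."""
    try:
        return ACT_FUNC[op_class_name]
    except KeyError:
        raise ValueError(f"unsupported middle op for fusion: {op_class_name}") from None

print(cudnn_kernel_names("ScaledClampedQGeGLU"))  # ('geglu', 'dgeglu')
```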
Reviews (4): Last reviewed commit: "Merge branch 'main' into fused_mxfp8_mlp..."
tests/pytorch/test_fusible_ops.py
if quantization == "nvfp4" and activation == "geglu" and bias:
    # TODO: ksivaman: Need to debug numerics for this case.
    pytest.skip("Bias/dbias not yet supported in NVFP4 fused grouped MLP with GeGLU")
Known numerics failure left unresolved
The skip guard documents a confirmed numerics defect ("Need to debug numerics for this case") for the nvfp4 + geglu + bias combination. That combination is part of the newly enabled activation="geglu" parametrize added in this very PR, so it represents a regression in the test matrix introduced here rather than a pre-existing gap. Merging with an acknowledged-but-undebugged wrong-output path may silently hide the root cause; a follow-up issue or a clear deferral note (e.g., a GitHub issue reference) would make this safer to track.
class ScaledSwiGLU(BasicOperation):
    r"""SwiGLU with post-scaling.

class _ScaledGLU(BasicOperation):
    """SwiGLU-family activation with per-row scales (fused grouped MLP middle op)."""
Nit: We're being a bit weird and nonstandard by treating SwiGLU as the canonical GLU. Sigmoid GLU came first and torch.nn.GLU is a sigmoid GLU. Also, we shouldn't imply that gated activations are specific to Transformer MLPs, since they were used for RNNs long before Transformers were invented.
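For context on this nit: the original GLU gates with a sigmoid (this is what `torch.nn.GLU` computes after splitting its input in half), while SwiGLU replaces the sigmoid gate with SiLU. A scalar pure-Python sketch for illustration — the split convention (which half gates, which half is linear) varies between implementations:

```python
import math

def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))

def glu(a: float, b: float) -> float:
    """Original gated linear unit: linear half times sigmoid of the gate half."""
    return a * sigmoid(b)

def swiglu(gate: float, b: float) -> float:
    """SwiGLU: the gate passes through SiLU (x * sigmoid(x)) instead of sigmoid."""
    return (gate * sigmoid(gate)) * b
```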
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
/te-ci pytorch L0
timmoon10 left a comment:
Overall LGTM, with minor suggestions.
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
/te-ci pytorch L0
Description
This PR adds support for the grouped MLP fused path via the cuDNN frontend when the scaled clamped SwiGLU activation is used. This activation is misnamed "geglu" in cudnn-frontend.

Type of change

Changes

- Add ScaledClampedSwiGLU in TE sequential (swiglu.py).

Checklist:
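As a rough reference for the activation being fused, here is a scalar sketch of the commonly used scaled + clamped SwiGLU variant with `alpha=1.702`, `limit=7.0`. The exact clamp placement is my reading of the popular variant, not taken from TE's kernel; with `alpha = 1.702`, `sigmoid(alpha * x)` approximates the GELU CDF, which is presumably why cudnn-frontend labels the op "geglu":

```python
import math

def scaled_clamped_swiglu(gate: float, up: float,
                          alpha: float = 1.702, limit: float = 7.0) -> float:
    """Scalar sketch of a scaled + clamped SwiGLU (illustrative only;
    TE's fused kernel may clamp or scale differently)."""
    gate = min(gate, limit)             # one-sided clamp on the gate half
    up = max(-limit, min(up, limit))    # symmetric clamp on the linear half
    silu = gate * (1.0 / (1.0 + math.exp(-alpha * gate)))
    return silu * (up + 1.0)
```

Note the clamps mean any `gate` above `limit` (or `up` outside `[-limit, limit]`) produces the same output as the boundary value — which is why the kernel's hard-coded `limit` must match the op's, per the review above.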