
[PyTorch] Support scaled + clamped SwiGLU in te.ops and enable fused MXFP8 grouped MLP#2855

Open
ksivaman wants to merge 5 commits into NVIDIA:main from ksivaman:fused_mxfp8_mlp_geglu

Conversation

@ksivaman
Member

@ksivaman ksivaman commented Apr 8, 2026

Description

This PR adds support for the fused grouped MLP path via the cuDNN frontend when the scaled, clamped SwiGLU activation is used. This activation is misnamed "geglu" in cudnn-frontend.
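For reference, a minimal pure-Python scalar sketch of one common scaled + clamped SwiGLU formulation (assuming the gpt-oss-style clamping with alpha=1.702 and limit=7.0 mentioned in the review below; the exact semantics of TE's kernel may differ):

```python
import math

def scaled_clamped_swiglu(gate: float, up: float, scale: float,
                          alpha: float = 1.702, limit: float = 7.0) -> float:
    """Illustrative scalar reference, not the actual TE kernel.

    The gate branch is clamped from above at `limit`, the linear branch is
    clamped to [-limit, limit], then an alpha-scaled swish gate is applied
    and the result is multiplied by a per-row `scale`.
    """
    g = min(gate, limit)
    u = max(-limit, min(up, limit))
    swish = g * (1.0 / (1.0 + math.exp(-alpha * g)))
    return scale * swish * (u + 1.0)
```

Note how the clamp makes the activation saturate: any gate value above `limit` produces the same output as `gate == limit`.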

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • Implement ScaledClampedSwiGLU in TE sequential.
  • Small refactor in swiglu.py.
  • Enable grouped MLP fused path when using this activation via cuDNN frontend.

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
@ksivaman ksivaman requested a review from timmoon10 April 8, 2026 18:50
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
@greptile-apps
Contributor

greptile-apps bot commented Apr 8, 2026

Greptile Summary

This PR introduces ScaledClampedQGeGLU (clamped SwiGLU with post-scale, mapping to cuDNN's misnamed "geglu" kernel) and wires it into the fused MXFP8 grouped MLP path alongside the existing ScaledSwiGLU.

  • Missing limit guard in fusion eligibility (_common.py lines 159–163): fuse_grouped_mlp_ops guards alpha (must be ≈ 1.702) against the cuDNN-hardcoded value, but does not check limit. A ScaledClampedQGeGLU(limit=5.0, glu_interleave_size=32) in an MXFP8 model would silently fuse with the cuDNN kernel that expects limit=7.0, producing wrong outputs with no error.

Confidence Score: 4/5

Safe to merge after adding the missing limit guard in fuse_grouped_mlp_ops; all other changes are consistent and well-tested.

One P1 issue remains: the fusion eligibility check in _common.py guards alpha but not limit, meaning a non-default limit value on ScaledClampedQGeGLU would silently route to the cuDNN kernel with the wrong clamp boundary. All other changes (new op class, kernel string selection, tests, docs) are correct and symmetric with the existing ScaledSwiGLU pattern.

transformer_engine/pytorch/ops/_common.py — fuse_grouped_mlp_ops needs a limit guard alongside the existing alpha guard.
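A minimal sketch of what the combined guard could look like (illustrative names and a standalone predicate, not the actual fuse_grouped_mlp_ops code; the hardcoded constants are taken from the review above):

```python
import math

# Values the review says are hardcoded in the cuDNN "geglu" kernel.
CUDNN_GEGLU_ALPHA = 1.702
CUDNN_GEGLU_LIMIT = 7.0

def can_fuse_clamped_glu(alpha: float, limit: float,
                         glu_interleave_size: int) -> bool:
    """Illustrative fusion-eligibility predicate."""
    if not math.isclose(alpha, CUDNN_GEGLU_ALPHA, rel_tol=1e-3):
        return False
    # The guard the review asks for: any non-default clamp limit must fall
    # back to the unfused path instead of silently using the wrong boundary.
    if not math.isclose(limit, CUDNN_GEGLU_LIMIT, rel_tol=1e-3):
        return False
    # The fused MXFP8 path also requires a 32-wide GLU interleave.
    return glu_interleave_size == 32
```

With such a guard, the `ScaledClampedQGeGLU(limit=5.0, glu_interleave_size=32)` case from the summary would skip fusion rather than produce wrong outputs.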

Vulnerabilities

No security concerns identified.

Important Files Changed

  • transformer_engine/pytorch/ops/_common.py: Extended fuse_grouped_mlp_ops to recognize ScaledClampedQGeGLU; guards alpha against the cuDNN-required 1.702 but omits a matching guard on limit (also cuDNN-hardcoded at 7.0), which can cause silent numerics errors.
  • transformer_engine/pytorch/ops/basic/swiglu.py: Adds ScaledClampedQGeGLU as a clean _ScaledGLU subclass that delegates kernel calls to a private ClampedSwiGLU instance; refactors the helpers _tex_clamped_swiglu_forward / _tex_clamped_dswiglu to avoid code duplication.
  • transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py: Widens the FC1/activation/FC2 type to accept ScaledClampedQGeGLU; correctly stores _cudnn_act_func="geglu" at construction time and passes it to the cuDNN kernel call.
  • transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py: Parallel backward changes: accepts ScaledClampedQGeGLU, sets _cudnn_dact_func="dgeglu", and uses it in the kernel call, mirroring the forward path.
  • tests/pytorch/test_fusible_ops.py: Adds unit tests for ScaledClampedQGeGLU (standalone + interleaved) and parametrizes test_grouped_mlp over both activations; the nvfp4+geglu+bias path is skipped with an unresolved numerics TODO.
  • docs/api/pytorch.rst: Adds an autoapiclass entry for ScaledClampedQGeGLU in the correct alphabetical location.
  • transformer_engine/pytorch/ops/basic/__init__.py: Exports ScaledClampedQGeGLU alongside the existing swiglu variants.

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[fuse_grouped_mlp_ops] --> B{window = GroupedLinear + ? + GroupedLinear}
    B -- yes --> C{isinstance ScaledSwiGLU or ScaledClampedQGeGLU}
    C -- ScaledClampedQGeGLU --> D{alpha ≈ 1.702?}
    D -- no --> E[skip fusion]
    D -- yes --> F["⚠️ limit check MISSING\n(cuDNN expects limit=7.0)"]
    F --> G{glu_interleave_size == 32?}
    G -- no --> E
    G -- yes --> H[Create fused op]
    H --> I{ForwardGroupedMLP\n_cudnn_act_func}
    H --> J{BackwardGroupedMLP\n_cudnn_dact_func}
    I -- ScaledClampedQGeGLU --> K["act_func='geglu'"]
    I -- ScaledSwiGLU --> L["act_func='swiglu'"]
    J -- ScaledClampedQGeGLU --> M["act_func='dgeglu'"]
    J -- ScaledSwiGLU --> N["act_func='dswiglu'"]
    C -- ScaledSwiGLU --> G

Reviews (4): Last reviewed commit: "Merge branch 'main' into fused_mxfp8_mlp..."

Comment on lines +3378 to +3380
if quantization == "nvfp4" and activation == "geglu" and bias:
# TODO: ksivaman: Need to debug numerics for this case.
pytest.skip("Bias/dbias not yet supported in NVFP4 fused grouped MLP with GeGLU")
Contributor

P1 Known numerics failure left unresolved

The skip guard documents a confirmed numerics defect ("Need to debug numerics for this case") for the nvfp4 + geglu + bias combination. That combination is part of the newly enabled activation="geglu" parametrize added in this very PR, so it represents a regression in the test matrix introduced here rather than a pre-existing gap. Merging with an acknowledged-but-undebugged wrong-output path may silently hide the root cause; a follow-up issue or a clear deferral note (e.g., a GitHub issue reference) would make this safer to track.

class ScaledSwiGLU(BasicOperation):
r"""SwiGLU with post-scaling.
class _ScaledGLU(BasicOperation):
"""SwiGLU-family activation with per-row scales (fused grouped MLP middle op)."""
Collaborator

Nit: We're being a bit weird and nonstandard by treating SwiGLU as the canonical GLU. Sigmoid GLU came first and torch.nn.GLU is a sigmoid GLU. Also, we shouldn't imply that gated activations are specific to Transformer MLPs, since they were used for RNNs long before Transformers were invented.
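For context, a pure-Python scalar sketch contrasting the original sigmoid GLU (the gating torch.nn.GLU applies, along a split of the last dimension) with the SwiGLU variant:

```python
import math

def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))

def glu(value: float, gate: float) -> float:
    # Original gated linear unit: value * sigmoid(gate).
    return value * sigmoid(gate)

def swiglu(gate: float, up: float) -> float:
    # SwiGLU replaces the sigmoid gate with a swish/SiLU gate:
    # swish(gate) * up, where swish(x) = x * sigmoid(x).
    return gate * sigmoid(gate) * up
```

This is why a `_ScaledGLU` base class named after SwiGLU reads oddly: the sigmoid form is the historical default.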

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

@ksivaman
Member Author

ksivaman commented Apr 8, 2026

/te-ci pytorch L0

timmoon10
timmoon10 previously approved these changes Apr 8, 2026
Collaborator

@timmoon10 timmoon10 left a comment


Overall LGTM, with minor suggestions.

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
@ksivaman
Member Author

ksivaman commented Apr 8, 2026

/te-ci pytorch L0
