
[PyTorch] Relax dimension constraints for using fused grouped MLP#2856

Open
ksivaman wants to merge 1 commit into NVIDIA:main from ksivaman:relax_dim_constraint_for_fused_grouped_mlp

Conversation

@ksivaman
Member

@ksivaman ksivaman commented Apr 8, 2026

Description

Previously we required all weight dims to be multiples of 256; this PR relaxes the requirement to multiples of 64, which is the actual constraint imposed by the kernel.
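
A minimal sketch of the relaxed alignment check. The helper name and signature below are illustrative only; the real check is part of validate_grouped_mlp_dims in transformer_engine/pytorch/ops/_common.py (see the file summary further down).

```python
# Illustrative sketch, not the actual implementation.
DIM_ALIGNMENT = 64  # relaxed from 256, matching the kernel's real requirement


def dims_allow_fused_grouped_mlp(fc1_in: int, fc1_out: int,
                                 fc2_in: int, fc2_out: int) -> bool:
    """Return True if all weight dims are multiples of the kernel alignment."""
    return all(d % DIM_ALIGNMENT == 0 for d in (fc1_in, fc1_out, fc2_in, fc2_out))
```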

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • Enable the fused grouped MLP path for dims that are multiples of 64 as well.
  • Fix the shapes used when creating SF views passed to cutedsl (ceiling division; see the sketch below).
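
A sketch of the scale-factor (SF) view shape fix, assuming the block layout described in the review flowchart below (128x128 scale blocks with a 32x4x4 swizzle). The helper and dimension names are illustrative, not the actual cutedsl API.

```python
# Illustrative sketch; the real views are built in
# transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py and
# backward_grouped_mlp.py.
def ceil_div(a: int, b: int) -> int:
    return (a + b - 1) // b


def sf_view_shape(m: int, k: int) -> tuple:
    """SF view shape for an (m, k) tensor.

    Plain floor division (m // 128, k // 128) under-counts scale blocks when
    a dimension is a multiple of 64 but not of 128 (e.g. k = 192); ceiling
    division, i.e. (x + 127) // 128, gives the correct count.
    """
    return (ceil_div(m, 128), ceil_div(k, 128), 32, 4, 4)
```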

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
@ksivaman ksivaman requested a review from timmoon10 April 8, 2026 20:07
@greptile-apps
Contributor

greptile-apps bot commented Apr 8, 2026

Greptile Summary

This PR relaxes the dimension constraint for the fused grouped MLP path from modulo-256 to modulo-64, matching the actual kernel requirement, and fixes scale-factor (SF) view shapes to use ceiling division when tensor dimensions are not multiples of 128. Tests are updated to also cover hidden_size=64.

Confidence Score: 5/5

Safe to merge; changes are mechanically correct, internally consistent, and covered by updated tests.

All findings are P2 or better. The constraint relaxation (256→64) matches the stated kernel requirement, the ceiling-division fix for SF views is mathematically necessary and correct, and a new hidden_size=64 test exercises the newly-enabled path. No logic errors, data corruption risks, or security concerns were found.

No files require special attention.

Vulnerabilities

No security concerns identified.

Important Files Changed

  • transformer_engine/pytorch/ops/_common.py: Constraint relaxed from % 256 to % 64 in both validate_grouped_mlp_dims and the fuse_grouped_mlp_ops sliding-window check; changes are symmetric and consistent.
  • transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py: All three scale-factor views (fc1 input, fc1 weight, fc2 weight) updated to ceiling division (x+127)//128, correctly handling dimensions that are multiples of 64 but not 128.
  • transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py: Same ceiling-division fix applied to fc2 dy scales, fc2 weight columnwise scales, and fc1 weight columnwise scales in the backward pass.
  • tests/pytorch/test_fusible_ops.py: hidden_size parametrized with (64, 256) to exercise the newly allowed 64-aligned dimensions in the fused grouped MLP test.
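
An illustrative pytest parametrization matching the test change listed above; the test name and body are hypothetical, only the hidden_size values (64, 256) come from the PR.

```python
# Illustrative only: the real test lives in tests/pytorch/test_fusible_ops.py
# and takes additional parameters.
import pytest


@pytest.mark.parametrize("hidden_size", (64, 256))
def test_fused_grouped_mlp(hidden_size: int) -> None:
    # Build a grouped MLP with the given hidden_size and check that the
    # fused path matches the unfused reference (details omitted here).
    ...
```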

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[Input dims check\nfc1.in_features % 64 == 0\nfc1.out_features % 64 == 0\nfc2.in_features % 64 == 0\nfc2.out_features % 64 == 0] -->|pass| B[Fused Grouped MLP path]
    A -->|fail| C[Unfused fallback]
    B --> D[Forward: scale view shapes\nceil m/128 x ceil k/128\nx 32 x 4 x 4]
    B --> E[Backward: scale view shapes\nceil m/128 x ceil k/128\nx 32 x 4 x 4]
    D --> F[CuTeGEMM + SwiGLU kernel]
    E --> G[CuTeGEMM dgrad/wgrad kernels]

Reviews (1): Last reviewed commit: "Reduce fused path dim constraint"

Collaborator

@timmoon10 timmoon10 left a comment


LGTM

@ksivaman
Member Author

ksivaman commented Apr 8, 2026

/te-ci pytorch L0

