[PyTorch] Relax dimension constraints for using fused grouped MLP #2856
ksivaman wants to merge 1 commit into NVIDIA:main
Conversation
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Greptile Summary

This PR relaxes the dimension constraint for the fused grouped MLP path from modulo-256 to modulo-64, matching the actual kernel requirement, and fixes scale-factor (SF) view shapes to use ceiling division when tensor dimensions are not multiples of 128. Tests are updated to also cover hidden_size=64.

Confidence Score: 5/5

Safe to merge; changes are mechanically correct, internally consistent, and covered by updated tests. All findings are P2 or better. The constraint relaxation (256 → 64) matches the stated kernel requirement, the ceiling-division fix for SF views is mathematically necessary and correct, and a new hidden_size=64 test exercises the newly enabled path. No logic errors, data-corruption risks, or security concerns were found. No files require special attention.
| Filename | Overview |
|---|---|
| transformer_engine/pytorch/ops/_common.py | Constraint relaxed from % 256 to % 64 in both validate_grouped_mlp_dims and the fuse_grouped_mlp_ops sliding-window check; changes are symmetric and consistent. |
| transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py | All three scale-factor views (fc1 input, fc1 weight, fc2 weight) updated to ceiling division (x+127)//128, correctly handling dimensions that are multiples of 64 but not 128. |
| transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py | Same ceiling-division fix applied to fc2 dy scales, fc2 weight columnwise scales, and fc1 weight columnwise scales in the backward pass. |
| tests/pytorch/test_fusible_ops.py | hidden_size parametrized with (64, 256) to exercise the newly-allowed 64-aligned dimensions in the fused grouped MLP test. |
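The ceiling-division change described in the table rows above can be sketched as follows. Note that `ceil_div` and the example shapes here are illustrative, not the repository's actual helper names; the fix replaces floor division `x // 128` with `(x + 127) // 128` when computing scale-factor view dimensions.

```python
def ceil_div(x: int, block: int = 128) -> int:
    """Ceiling division: number of block-sized tiles needed to cover x elements."""
    return (x + block - 1) // block

# Floor division under-counts tiles when a dimension is a multiple of 64
# but not of 128, e.g. hidden_size = 64:
hidden = 64
print(hidden // 128)     # floor: 0 tiles -> zero-sized scale view (wrong)
print(ceil_div(hidden))  # ceil:  1 tile  -> covers the full dimension
```

This is why the fix matters exactly in the newly allowed regime: dimensions that are multiples of 64 but not of 128 would otherwise produce empty scale-factor views.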
Flowchart
```mermaid
%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[Input dims check\nfc1.in_features % 64 == 0\nfc1.out_features % 64 == 0\nfc2.in_features % 64 == 0\nfc2.out_features % 64 == 0] -->|pass| B[Fused Grouped MLP path]
    A -->|fail| C[Unfused fallback]
    B --> D[Forward: scale view shapes\nceil m/128 x ceil k/128\nx 32 x 4 x 4]
    B --> E[Backward: scale view shapes\nceil m/128 x ceil k/128\nx 32 x 4 x 4]
    D --> F[CuTeGEMM + SwiGLU kernel]
    E --> G[CuTeGEMM dgrad/wgrad kernels]
```
Reviews (1): Last reviewed commit: "Reduce fused path dim constraint"
/te-ci pytorch L0
Description
Previously we required all weight dims to be multiples of 256; this PR relaxes the requirement to multiples of 64, which is the actual constraint from the kernel.
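A minimal sketch of the relaxed eligibility check; the function signature and argument names here are illustrative, not necessarily those in `transformer_engine/pytorch/ops/_common.py`:

```python
def grouped_mlp_dims_supported(fc1_in: int, fc1_out: int,
                               fc2_in: int, fc2_out: int,
                               align: int = 64) -> bool:
    """Fused grouped MLP path requires every weight dim to be a multiple
    of `align` (previously 256, relaxed to 64 to match the kernel)."""
    return all(d % align == 0 for d in (fc1_in, fc1_out, fc2_in, fc2_out))

# hidden_size = 64 was rejected under the old modulo-256 rule but
# passes under the relaxed modulo-64 rule:
print(grouped_mlp_dims_supported(64, 256, 256, 64))              # True
print(grouped_mlp_dims_supported(64, 256, 256, 64, align=256))   # False
```

When the check fails, the op falls back to the unfused path, as shown in the flowchart above.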
Type of change
Changes
Checklist: