Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion tests/pytorch/test_fusible_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -3249,12 +3249,13 @@ def to_cpu(tensor: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
@pytest.mark.parametrize("accumulate_into_main_grad", (False, True))
@pytest.mark.parametrize("glu_interleave_size", (None, 32))
@pytest.mark.parametrize("delay_wgrad_compute", (False, True))
@pytest.mark.parametrize("hidden_size", (64, 256))
def test_grouped_mlp(
self,
*,
group_size: int = 4,
bias: bool,
hidden_size: int = 256,
hidden_size: int,
dtype: torch.dtype,
quantization: Optional[str],
single_grouped_weight: bool,
Expand Down
12 changes: 6 additions & 6 deletions transformer_engine/pytorch/ops/_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,12 +76,12 @@ def get_fp8_meta_from_fp8_tensor(tensor: Float8Tensor) -> tuple[FP8TensorMeta, i
def validate_grouped_mlp_dims(fc1, swiglu, fc2) -> None:
"""Validate FC1/SwiGLU/FC2 dimensions and interleave size for fused grouped MLP."""

if fc1.in_features % 256 != 0 or fc1.out_features % 256 != 0:
if fc1.in_features % 64 != 0 or fc1.out_features % 64 != 0:
raise ValueError(
f"Unsupported dims for FC1 (num_groups={fc1.num_groups}, "
f"in_features={fc1.in_features}, out_features={fc1.out_features})."
)
if fc2.in_features % 256 != 0 or fc2.out_features % 256 != 0:
if fc2.in_features % 64 != 0 or fc2.out_features % 64 != 0:
raise ValueError(
f"Unsupported dims for FC2 (num_groups={fc2.num_groups}, "
f"in_features={fc2.in_features}, out_features={fc2.out_features})."
Expand Down Expand Up @@ -153,10 +153,10 @@ def fuse_grouped_mlp_ops(
elif window[0].num_groups != window[2].num_groups:
matches_pattern = False
elif (
window[0].in_features % 256 != 0
or window[0].out_features % 256 != 0
or window[2].in_features % 256 != 0
or window[2].out_features % 256 != 0
window[0].in_features % 64 != 0
or window[0].out_features % 64 != 0
or window[2].in_features % 64 != 0
or window[2].out_features % 64 != 0
):
matches_pattern = False
elif window[1].glu_interleave_size != 32:
Expand Down
12 changes: 6 additions & 6 deletions transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,8 +401,8 @@ def fuser_backward(
fc2_dy_scales = fc2_dy_scales.view(dtype=torch.float8_e8m0fnu)
fc2_dy_scales = fc2_dy_scales.view(
1,
out_shape[0] // 128,
out_shape[1] // 128,
(out_shape[0] + 127) // 128,
(out_shape[1] + 127) // 128,
MXFP8_BLOCK_SCALING_SIZE,
4,
4,
Expand Down Expand Up @@ -451,8 +451,8 @@ def fuser_backward(
fc2_w_scales = fc2_weight_for_gemm.columnwise_scale_inv.view(dtype=torch.float8_e8m0fnu)
fc2_w_scales = fc2_w_scales.view(
num_groups,
fc2_weight_shape[1] // 128,
fc2_weight_shape[0] // 128,
(fc2_weight_shape[1] + 127) // 128,
(fc2_weight_shape[0] + 127) // 128,
MXFP8_BLOCK_SCALING_SIZE,
4,
4,
Expand Down Expand Up @@ -582,8 +582,8 @@ def fuser_backward(
)
fc1_w_scales = fc1_w_scales.view(
num_groups,
fc1_weight_shape[1] // 128,
fc1_weight_shape[0] // 128,
(fc1_weight_shape[1] + 127) // 128,
(fc1_weight_shape[0] + 127) // 128,
MXFP8_BLOCK_SCALING_SIZE,
4,
4,
Expand Down
12 changes: 6 additions & 6 deletions transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,8 +309,8 @@ def fuser_forward(
fc1_x_scales = fc1_x_scales.view(dtype=torch.float8_e8m0fnu)
fc1_x_scales = fc1_x_scales.view(
1,
in_shape[0] // 128,
in_shape[1] // 128,
(in_shape[0] + 127) // 128,
(in_shape[1] + 127) // 128,
MXFP8_BLOCK_SCALING_SIZE,
4,
4,
Expand Down Expand Up @@ -358,8 +358,8 @@ def fuser_forward(
fc1_w_scales = fc1_weight_for_gemm.scale_inv.view(dtype=torch.float8_e8m0fnu)
fc1_w_scales = fc1_w_scales.view(
num_groups,
fc1_weight_shape[0] // 128,
fc1_weight_shape[1] // 128,
(fc1_weight_shape[0] + 127) // 128,
(fc1_weight_shape[1] + 127) // 128,
MXFP8_BLOCK_SCALING_SIZE,
4,
4,
Expand Down Expand Up @@ -455,8 +455,8 @@ def fuser_forward(
fc2_w_scales = fc2_weight_for_gemm.scale_inv.view(dtype=torch.float8_e8m0fnu)
fc2_w_scales = fc2_w_scales.view(
num_groups,
fc2_weight_shape[0] // 128,
fc2_weight_shape[1] // 128,
(fc2_weight_shape[0] + 127) // 128,
(fc2_weight_shape[1] + 127) // 128,
MXFP8_BLOCK_SCALING_SIZE,
4,
4,
Expand Down
Loading