Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/api/pytorch.rst
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,8 @@ Operation fuser

.. autoapiclass:: transformer_engine.pytorch.ops.SReLU

.. autoapiclass:: transformer_engine.pytorch.ops.ScaledClampedQGeGLU

.. autoapiclass:: transformer_engine.pytorch.ops.ScaledSwiGLU

.. autoapiclass:: transformer_engine.pytorch.ops.SiLU
Expand Down
117 changes: 113 additions & 4 deletions tests/pytorch/test_fusible_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -2234,6 +2234,91 @@ def test_interleaved_scaled_swiglu(self):
scales_requires_grad=True,
)

@pytest.mark.parametrize("in_shape", ((71, 192), (5, 7, 128)))
@pytest.mark.parametrize("input_requires_grad", (False, True))
@pytest.mark.parametrize("scales_requires_grad", (False, True))
def test_scaled_clamped_qgeglu(
    self,
    *,
    in_shape: Iterable[int],
    glu_interleave_size: Optional[int] = None,
    dtype: torch.dtype = torch.float32,
    device: torch.device = "cuda",
    input_requires_grad: bool,
    scales_requires_grad: bool,
    limit: float = 7.0,
    alpha: float = 1.702,
) -> None:
    """ScaledClampedQGeGLU (clamped QGeGLU with post-scale)

    Compares the fusible op against a plain PyTorch reference that
    clamps both GLU halves, applies QuickGELU
    (``x * sigmoid(alpha * x)``) to the gate half, and multiplies the
    result by a per-row scale.
    """

    # GLU halves the last dim; all leading dims are preserved
    result_shape = list(in_shape)
    result_shape[-1] //= 2

    # Random data
    x_ref, x_test = make_reference_and_test_tensors(
        in_shape,
        test_dtype=dtype,
        test_device=device,
        requires_grad=input_requires_grad,
    )
    scales_ref, scales_test = make_reference_and_test_tensors(
        in_shape[:-1],
        test_dtype=dtype,
        test_device=device,
        requires_grad=scales_requires_grad,
    )
    dy_ref, dy_test = make_reference_and_test_tensors(
        result_shape,
        test_dtype=dtype,
        test_device=device,
        requires_grad=False,
    )

    # Plain PyTorch reference (matches :class:`ClampedSwiGLU` numerics)
    ref_in = x_ref
    if glu_interleave_size is not None:
        # Undo the block-interleaved layout before splitting the halves
        num_blocks = in_shape[-1] // (2 * glu_interleave_size)
        ref_in = ref_in.reshape(-1, num_blocks, 2, glu_interleave_size)
        ref_in = ref_in.transpose(1, 2)
        ref_in = ref_in.reshape(in_shape)
    gate, linear = ref_in.chunk(2, dim=-1)
    gate = gate.clamp_max(limit)
    linear = linear.clamp(min=-limit, max=limit)
    quick_gelu = gate * torch.sigmoid(alpha * gate)
    y_ref = scales_ref[..., None] * (quick_gelu * (linear + 1))
    needs_backward = input_requires_grad or scales_requires_grad
    if needs_backward:
        y_ref.backward(dy_ref)

    # Implementation under test
    op = te_ops.ScaledClampedQGeGLU(
        glu_interleave_size=glu_interleave_size,
        limit=limit,
        alpha=alpha,
    )
    y_test = op(x_test, scales_test)
    if needs_backward:
        y_test.backward(dy_test)

    # Check results
    tols = dtype_tols(dtype)
    y_test = y_test.to(dtype=torch.float64, device="cpu")
    assert_close(y_test, y_ref, **tols)
    assert_close_grads(x_test, x_ref, **tols)
    assert_close_grads(scales_test, scales_ref, **tols)

def test_interleaved_scaled_clamped_qgeglu(self):
    """ScaledClampedQGeGLU with block interleaved input format"""
    # Reuse the main test with an interleave size that divides the
    # GLU dimension (192 / 2 = 96 = 3 blocks of 32).
    case = dict(
        in_shape=(32, 192),
        glu_interleave_size=32,
        input_requires_grad=True,
        scales_requires_grad=True,
    )
    self.test_scaled_clamped_qgeglu(**case)


class TestFusedOps:
"""Tests for fused operations"""
Expand Down Expand Up @@ -3249,6 +3334,7 @@ def to_cpu(tensor: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
@pytest.mark.parametrize("accumulate_into_main_grad", (False, True))
@pytest.mark.parametrize("glu_interleave_size", (None, 32))
@pytest.mark.parametrize("delay_wgrad_compute", (False, True))
@pytest.mark.parametrize("activation", ("scaled_swiglu", "scaled_clamped_qgeglu"))
def test_grouped_mlp(
self,
*,
Expand All @@ -3264,8 +3350,9 @@ def test_grouped_mlp(
split_alignment: int = 256,
glu_interleave_size: Optional[int],
delay_wgrad_compute: bool,
activation: str,
) -> None:
"""GroupedLinear + ScaledSwiGLU + GroupedLinear"""
"""GroupedLinear + ScaledSwiGLU / ScaledClampedQGeGLU + GroupedLinear"""

# Split sizes
split_sizes = [split_alignment * (i) for i in range(group_size)]
Expand All @@ -3288,6 +3375,9 @@ def test_grouped_mlp(
if quantization == "mxfp8" and bias:
# Will be supported in future CUDNN release.
pytest.skip("Bias/dbias not yet supported in MXFP8 fused grouped MLP")
if quantization == "nvfp4" and activation == "scaled_clamped_qgeglu" and bias:
# TODO: ksivaman: Need to debug numerics for this case.
pytest.skip("Bias/dbias not yet supported in NVFP4 fused grouped MLP with GeGLU")

# Random data
x_ref, x_test = make_reference_and_test_tensors(
Expand Down Expand Up @@ -3376,7 +3466,14 @@ def test_grouped_mlp(
x = x.transpose(1, 2)
x = x.reshape(-1, 2 * hidden_size)
x1, x2 = x.chunk(2, dim=-1)
x = torch.nn.functional.silu(x1) * x2
if activation == "scaled_swiglu":
x = torch.nn.functional.silu(x1) * x2
else:
lim = torch.tensor(7.0, device=x1.device, dtype=x1.dtype)
geglu_alpha = 1.702
x1c = torch.minimum(x1, lim)
x2c = torch.clamp(x2, -lim, lim)
x = (x2c + 1) * (x1c * torch.sigmoid(geglu_alpha * x1c))
x = x * probs[group_idx].unsqueeze(-1)
x = torch.nn.functional.linear(x, fc2_ws_ref[group_idx], bias=fc2_bs_ref[group_idx])
ys.append(x)
Expand All @@ -3385,6 +3482,11 @@ def test_grouped_mlp(

# Construct operations
recipe = make_recipe(quantization)
scaled_act = (
te_ops.ScaledSwiGLU(glu_interleave_size=glu_interleave_size)
if activation == "scaled_swiglu"
else te_ops.ScaledClampedQGeGLU(glu_interleave_size=glu_interleave_size)
)
with te.quantized_model_init(enabled=with_quantization, recipe=recipe):
fc1 = te_ops.GroupedLinear(
group_size,
Expand Down Expand Up @@ -3412,7 +3514,7 @@ def test_grouped_mlp(
)
module = te_ops.Sequential(
fc1,
te_ops.ScaledSwiGLU(glu_interleave_size=glu_interleave_size),
scaled_act,
fc2,
)

Expand Down Expand Up @@ -3572,13 +3674,15 @@ def test_grouped_mlp(
@pytest.mark.parametrize("dtype", _dtypes)
@pytest.mark.parametrize("single_grouped_weight", (False, True))
@pytest.mark.parametrize("accumulate_into_main_grad", (False, True))
@pytest.mark.parametrize("activation", ("scaled_swiglu", "scaled_clamped_qgeglu"))
@pytest.mark.skipif(not mxfp8_available, reason=reason_for_no_mxfp8)
def test_grouped_mlp_cuda_graph_safe_mxfp8(
self,
*,
dtype: torch.dtype,
single_grouped_weight: bool,
accumulate_into_main_grad: bool,
activation: str,
device: torch.device = "cuda",
group_size: int = 4,
hidden_size: int = 256,
Expand Down Expand Up @@ -3619,9 +3723,14 @@ def test_grouped_mlp_cuda_graph_safe_mxfp8(
single_grouped_weight=single_grouped_weight,
accumulate_into_main_grad=accumulate_into_main_grad,
)
scaled_act = (
te_ops.ScaledSwiGLU(glu_interleave_size=glu_interleave_size)
if activation == "scaled_swiglu"
else te_ops.ScaledClampedQGeGLU(glu_interleave_size=glu_interleave_size)
)
module = te_ops.Sequential(
fc1,
te_ops.ScaledSwiGLU(glu_interleave_size=glu_interleave_size),
scaled_act,
fc2,
)

Expand Down
27 changes: 19 additions & 8 deletions transformer_engine/pytorch/ops/_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,8 @@ def get_fp8_meta_from_fp8_tensor(tensor: Float8Tensor) -> tuple[FP8TensorMeta, i
return fp8_meta, 0


def validate_grouped_mlp_dims(fc1, swiglu, fc2) -> None:
"""Validate FC1/SwiGLU/FC2 dimensions and interleave size for fused grouped MLP."""
def validate_grouped_mlp_dims(fc1, glu_op, fc2) -> None:
"""Validate FC1 / scaled GLU / FC2 dimensions for fused grouped MLP."""

if fc1.in_features % 256 != 0 or fc1.out_features % 256 != 0:
raise ValueError(
Expand All @@ -93,10 +93,10 @@ def validate_grouped_mlp_dims(fc1, swiglu, fc2) -> None:
f"and FC2 (num_groups={fc2.num_groups}, in_features={fc2.in_features}, "
f"out_features={fc2.out_features}) do not match."
)
if swiglu.glu_interleave_size != 32:
if glu_op.glu_interleave_size != 32:
raise ValueError(
"Fused kernel requires 32-wide GLU interleaving, "
f"but got glu_interleave_size={swiglu.glu_interleave_size}."
f"but got glu_interleave_size={glu_op.glu_interleave_size}."
)


Expand All @@ -106,7 +106,7 @@ def fuse_grouped_mlp_ops(
recipe,
fused_op_cls,
):
"""Sliding-window fusion for GroupedLinear + ScaledSwiGLU + GroupedLinear.
"""Sliding-window fusion for GroupedLinear + scaled GLU + GroupedLinear.

Parameters
----------
Expand All @@ -116,7 +116,9 @@ def fuse_grouped_mlp_ops(
Quantization recipe.
fused_op_cls : type
Fused operation class with ``is_supported()`` classmethod and
constructor accepting ``fc1``, ``swiglu``, ``fc2`` keyword args.
constructor accepting ``fc1``, ``glu_op``, ``fc2`` keyword args. The
``glu_op`` must be :class:`~transformer_engine.pytorch.ops.basic.swiglu.ScaledSwiGLU`
or :class:`~transformer_engine.pytorch.ops.basic.swiglu.ScaledClampedQGeGLU`.
May also expose ``is_fc1_bias_supported()`` and/or
``is_fc2_bias_supported()`` classmethods for bias eligibility.

Expand All @@ -125,7 +127,11 @@ def fuse_grouped_mlp_ops(
list of FusibleOperation
Updated operations with matched triples replaced by fused ops.
"""
from .basic import GroupedLinear, ScaledSwiGLU # pylint: disable=import-outside-toplevel
from .basic import ( # pylint: disable=import-outside-toplevel
GroupedLinear,
ScaledClampedQGeGLU,
ScaledSwiGLU,
)

if not fused_op_cls.is_supported():
return ops
Expand All @@ -146,10 +152,15 @@ def fuse_grouped_mlp_ops(
matches_pattern = True
if not (
isinstance(window[0], GroupedLinear)
and isinstance(window[1], ScaledSwiGLU)
and isinstance(window[1], (ScaledSwiGLU, ScaledClampedQGeGLU))
and isinstance(window[2], GroupedLinear)
):
matches_pattern = False
elif (
isinstance(window[1], ScaledClampedQGeGLU)
and abs(window[1]._clamped.alpha - 1.702) > 0.001
):
matches_pattern = False
elif window[0].num_groups != window[2].num_groups:
matches_pattern = False
elif (
Expand Down
2 changes: 1 addition & 1 deletion transformer_engine/pytorch/ops/basic/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,4 +32,4 @@
from .reduce_scatter import ReduceScatter
from .reshape import Reshape
from .rmsnorm import RMSNorm
from .swiglu import ClampedSwiGLU, ScaledSwiGLU, SwiGLU
from .swiglu import ClampedSwiGLU, ScaledClampedQGeGLU, ScaledSwiGLU, SwiGLU
Loading
Loading