diff --git a/docs/api/pytorch.rst b/docs/api/pytorch.rst index 1fe4f19990..3217d29c3b 100644 --- a/docs/api/pytorch.rst +++ b/docs/api/pytorch.rst @@ -221,6 +221,8 @@ Operation fuser .. autoapiclass:: transformer_engine.pytorch.ops.SReLU +.. autoapiclass:: transformer_engine.pytorch.ops.ScaledClampedQGeGLU + .. autoapiclass:: transformer_engine.pytorch.ops.ScaledSwiGLU .. autoapiclass:: transformer_engine.pytorch.ops.SiLU diff --git a/tests/pytorch/test_fusible_ops.py b/tests/pytorch/test_fusible_ops.py index 75d450b46b..795cbf3452 100644 --- a/tests/pytorch/test_fusible_ops.py +++ b/tests/pytorch/test_fusible_ops.py @@ -18,6 +18,9 @@ import transformer_engine.common.recipe import transformer_engine.pytorch as te import transformer_engine.pytorch.ops as te_ops +from transformer_engine.pytorch.ops._common import ( + _nvidia_cudnn_frontend_supports_scaled_clamped_qgeglu, +) from transformer_engine.pytorch.ops.fused import ( BackwardActivationBias, @@ -2234,6 +2237,91 @@ def test_interleaved_scaled_swiglu(self): scales_requires_grad=True, ) + @pytest.mark.parametrize("in_shape", ((71, 192), (5, 7, 128))) + @pytest.mark.parametrize("input_requires_grad", (False, True)) + @pytest.mark.parametrize("scales_requires_grad", (False, True)) + def test_scaled_clamped_qgeglu( + self, + *, + in_shape: Iterable[int], + glu_interleave_size: Optional[int] = None, + dtype: torch.dtype = torch.float32, + device: torch.device = "cuda", + input_requires_grad: bool, + scales_requires_grad: bool, + limit: float = 7.0, + alpha: float = 1.702, + ) -> None: + """ScaledClampedQGeGLU (clamped QGeGLU with post-scale)""" + + # Tensor dims + out_shape = list(in_shape) + out_shape[-1] //= 2 + + # Random data + x_ref, x_test = make_reference_and_test_tensors( + in_shape, + test_dtype=dtype, + test_device=device, + requires_grad=input_requires_grad, + ) + scales_ref, scales_test = make_reference_and_test_tensors( + in_shape[:-1], + test_dtype=dtype, + test_device=device, + 
requires_grad=scales_requires_grad, + ) + dy_ref, dy_test = make_reference_and_test_tensors( + out_shape, + test_dtype=dtype, + test_device=device, + requires_grad=False, + ) + + # Plain PyTorch reference (matches :class:`ClampedSwiGLU` numerics) + x = x_ref + if glu_interleave_size is not None: + x = x.reshape( + -1, + in_shape[-1] // (2 * glu_interleave_size), + 2, + glu_interleave_size, + ) + x = x.transpose(1, 2) + x = x.reshape(in_shape) + x_glu, x_linear = x.chunk(2, dim=-1) + x_glu = x_glu.clamp(min=None, max=limit) + x_linear = x_linear.clamp(min=-limit, max=limit) + out_glu = x_glu * torch.sigmoid(alpha * x_glu) + y = out_glu * (x_linear + 1) + y_ref = scales_ref.unsqueeze(-1) * y + if input_requires_grad or scales_requires_grad: + y_ref.backward(dy_ref) + + op = te_ops.ScaledClampedQGeGLU( + glu_interleave_size=glu_interleave_size, + limit=limit, + alpha=alpha, + ) + y_test = op(x_test, scales_test) + if input_requires_grad or scales_requires_grad: + y_test.backward(dy_test) + + tols = dtype_tols(dtype) + y_test = y_test.to(dtype=torch.float64, device="cpu") + assert_close(y_test, y_ref, **tols) + assert_close_grads(x_test, x_ref, **tols) + assert_close_grads(scales_test, scales_ref, **tols) + + def test_interleaved_scaled_clamped_qgeglu(self): + """ScaledClampedQGeGLU with block interleaved input format""" + self.test_scaled_clamped_qgeglu( + in_shape=(32, 192), + glu_interleave_size=32, + input_requires_grad=True, + scales_requires_grad=True, + ) + class TestFusedOps: """Tests for fused operations""" @@ -3249,6 +3337,7 @@ def to_cpu(tensor: Optional[torch.Tensor]) -> Optional[torch.Tensor]: @pytest.mark.parametrize("accumulate_into_main_grad", (False, True)) @pytest.mark.parametrize("glu_interleave_size", (None, 32)) @pytest.mark.parametrize("delay_wgrad_compute", (False, True)) + @pytest.mark.parametrize("activation", ("scaled_swiglu", "scaled_clamped_qgeglu")) def test_grouped_mlp( self, *, @@ -3264,8 +3353,9 @@ def test_grouped_mlp( split_alignment: 
int = 256, glu_interleave_size: Optional[int], delay_wgrad_compute: bool, + activation: str, ) -> None: - """GroupedLinear + ScaledSwiGLU + GroupedLinear""" + """GroupedLinear + ScaledSwiGLU / ScaledClampedQGeGLU + GroupedLinear""" # Split sizes split_sizes = [split_alignment * (i) for i in range(group_size)] @@ -3288,6 +3378,9 @@ def test_grouped_mlp( if quantization == "mxfp8" and bias: # Will be supported in future CUDNN release. pytest.skip("Bias/dbias not yet supported in MXFP8 fused grouped MLP") + if quantization == "nvfp4" and activation == "scaled_clamped_qgeglu" and bias: + # TODO: ksivaman: Need to debug numerics for this case. + pytest.skip("Bias/dbias not yet supported in NVFP4 fused grouped MLP with GeGLU") # Random data x_ref, x_test = make_reference_and_test_tensors( @@ -3376,7 +3469,14 @@ def test_grouped_mlp( x = x.transpose(1, 2) x = x.reshape(-1, 2 * hidden_size) x1, x2 = x.chunk(2, dim=-1) - x = torch.nn.functional.silu(x1) * x2 + if activation == "scaled_swiglu": + x = torch.nn.functional.silu(x1) * x2 + else: + lim = torch.tensor(7.0, device=x1.device, dtype=x1.dtype) + geglu_alpha = 1.702 + x1c = torch.minimum(x1, lim) + x2c = torch.clamp(x2, -lim, lim) + x = (x2c + 1) * (x1c * torch.sigmoid(geglu_alpha * x1c)) x = x * probs[group_idx].unsqueeze(-1) x = torch.nn.functional.linear(x, fc2_ws_ref[group_idx], bias=fc2_bs_ref[group_idx]) ys.append(x) @@ -3385,6 +3485,11 @@ def test_grouped_mlp( # Construct operations recipe = make_recipe(quantization) + scaled_act = ( + te_ops.ScaledSwiGLU(glu_interleave_size=glu_interleave_size) + if activation == "scaled_swiglu" + else te_ops.ScaledClampedQGeGLU(glu_interleave_size=glu_interleave_size) + ) with te.quantized_model_init(enabled=with_quantization, recipe=recipe): fc1 = te_ops.GroupedLinear( group_size, @@ -3412,7 +3517,7 @@ def test_grouped_mlp( ) module = te_ops.Sequential( fc1, - te_ops.ScaledSwiGLU(glu_interleave_size=glu_interleave_size), + scaled_act, fc2, ) @@ -3484,6 +3589,10 @@ def 
test_grouped_mlp( quantization == "mxfp8" and dtype in (torch.bfloat16, torch.float16) and glu_interleave_size == 32 + and ( + activation != "scaled_clamped_qgeglu" + or _nvidia_cudnn_frontend_supports_scaled_clamped_qgeglu() + ) ): if te_ops.fused.ForwardGroupedMLP_CuTeGEMMSwiGLU_MXFP8.is_supported(): forward_ops = module._module_groups[0]._forward_ops @@ -3572,6 +3681,7 @@ def test_grouped_mlp( @pytest.mark.parametrize("dtype", _dtypes) @pytest.mark.parametrize("single_grouped_weight", (False, True)) @pytest.mark.parametrize("accumulate_into_main_grad", (False, True)) + @pytest.mark.parametrize("activation", ("scaled_swiglu", "scaled_clamped_qgeglu")) @pytest.mark.skipif(not mxfp8_available, reason=reason_for_no_mxfp8) def test_grouped_mlp_cuda_graph_safe_mxfp8( self, @@ -3579,6 +3689,7 @@ def test_grouped_mlp_cuda_graph_safe_mxfp8( dtype: torch.dtype, single_grouped_weight: bool, accumulate_into_main_grad: bool, + activation: str, device: torch.device = "cuda", group_size: int = 4, hidden_size: int = 256, @@ -3591,6 +3702,12 @@ def test_grouped_mlp_cuda_graph_safe_mxfp8( pytest.skip("MXFP8 fused grouped MLP is not supported on this system") if dtype not in (torch.bfloat16, torch.float16): pytest.skip("MXFP8 fused grouped MLP is only supported with BF16/FP16") + if activation == "scaled_clamped_qgeglu" and not ( + _nvidia_cudnn_frontend_supports_scaled_clamped_qgeglu() + ): + pytest.skip( + "ScaledClampedQGeGLU fused grouped MLP requires nvidia-cudnn-frontend >= 1.23.0" + ) split_sizes = [split_alignment * (i + 1) for i in range(group_size)] random.shuffle(split_sizes) @@ -3619,9 +3736,14 @@ def test_grouped_mlp_cuda_graph_safe_mxfp8( single_grouped_weight=single_grouped_weight, accumulate_into_main_grad=accumulate_into_main_grad, ) + scaled_act = ( + te_ops.ScaledSwiGLU(glu_interleave_size=glu_interleave_size) + if activation == "scaled_swiglu" + else te_ops.ScaledClampedQGeGLU(glu_interleave_size=glu_interleave_size) + ) module = te_ops.Sequential( fc1, - 
te_ops.ScaledSwiGLU(glu_interleave_size=glu_interleave_size), + scaled_act, fc2, ) diff --git a/transformer_engine/pytorch/ops/_common.py b/transformer_engine/pytorch/ops/_common.py index 0e03e691f3..ae8b48a90d 100644 --- a/transformer_engine/pytorch/ops/_common.py +++ b/transformer_engine/pytorch/ops/_common.py @@ -5,9 +5,12 @@ """Helper functions used in fusible operations.""" from __future__ import annotations +import functools +from importlib.metadata import PackageNotFoundError, version as get_pkg_version from typing import Optional import torch +from packaging.version import Version as PkgVersion from transformer_engine_torch import FP8TensorMeta from ..torch_version import torch_version @@ -17,6 +20,15 @@ from ..utils import canonicalize_dtype +@functools.lru_cache(maxsize=1) +def _nvidia_cudnn_frontend_supports_scaled_clamped_qgeglu() -> bool: + """Check cuDNN FE min version with fixed numerics for qgeglu.""" + try: + return PkgVersion(get_pkg_version("nvidia-cudnn-frontend")) >= PkgVersion("1.23.0") + except PackageNotFoundError: + return False + + def is_quantized_tensor(tensor: torch.Tensor | QuantizedTensorStorage) -> bool: """Check if tensor is a quantized tensor""" return isinstance(tensor, QuantizedTensorStorage) @@ -73,8 +85,8 @@ def get_fp8_meta_from_fp8_tensor(tensor: Float8Tensor) -> tuple[FP8TensorMeta, i return fp8_meta, 0 -def validate_grouped_mlp_dims(fc1, swiglu, fc2) -> None: - """Validate FC1/SwiGLU/FC2 dimensions and interleave size for fused grouped MLP.""" +def validate_grouped_mlp_dims(fc1, glu_op, fc2) -> None: + """Validate FC1 / scaled GLU / FC2 dimensions for fused grouped MLP.""" if fc1.in_features % 256 != 0 or fc1.out_features % 256 != 0: raise ValueError( @@ -93,10 +105,10 @@ def validate_grouped_mlp_dims(fc1, swiglu, fc2) -> None: f"and FC2 (num_groups={fc2.num_groups}, in_features={fc2.in_features}, " f"out_features={fc2.out_features}) do not match." 
) - if swiglu.glu_interleave_size != 32: + if glu_op.glu_interleave_size != 32: raise ValueError( "Fused kernel requires 32-wide GLU interleaving, " - f"but got glu_interleave_size={swiglu.glu_interleave_size}." + f"but got glu_interleave_size={glu_op.glu_interleave_size}." ) @@ -106,7 +118,7 @@ def fuse_grouped_mlp_ops( recipe, fused_op_cls, ): - """Sliding-window fusion for GroupedLinear + ScaledSwiGLU + GroupedLinear. + """Sliding-window fusion for GroupedLinear + scaled GLU + GroupedLinear. Parameters ---------- @@ -116,7 +128,9 @@ def fuse_grouped_mlp_ops( Quantization recipe. fused_op_cls : type Fused operation class with ``is_supported()`` classmethod and - constructor accepting ``fc1``, ``swiglu``, ``fc2`` keyword args. + constructor accepting ``fc1``, ``swiglu``, ``fc2`` keyword args. The + ``swiglu`` must be :class:`~transformer_engine.pytorch.ops.basic.swiglu.ScaledSwiGLU` + or :class:`~transformer_engine.pytorch.ops.basic.swiglu.ScaledClampedQGeGLU`. May also expose ``is_fc1_bias_supported()`` and/or ``is_fc2_bias_supported()`` classmethods for bias eligibility. @@ -125,7 +139,11 @@ def fuse_grouped_mlp_ops( list of FusibleOperation Updated operations with matched triples replaced by fused ops. 
""" - from .basic import GroupedLinear, ScaledSwiGLU # pylint: disable=import-outside-toplevel + from .basic import ( # pylint: disable=import-outside-toplevel + GroupedLinear, + ScaledClampedQGeGLU, + ScaledSwiGLU, + ) if not fused_op_cls.is_supported(): return ops @@ -146,10 +164,15 @@ def fuse_grouped_mlp_ops( matches_pattern = True if not ( isinstance(window[0], GroupedLinear) - and isinstance(window[1], ScaledSwiGLU) + and isinstance(window[1], (ScaledSwiGLU, ScaledClampedQGeGLU)) and isinstance(window[2], GroupedLinear) ): matches_pattern = False + elif isinstance(window[1], ScaledClampedQGeGLU) and ( + abs(window[1]._clamped.alpha - 1.702) > 0.001 + or not _nvidia_cudnn_frontend_supports_scaled_clamped_qgeglu() + ): + matches_pattern = False elif window[0].num_groups != window[2].num_groups: matches_pattern = False elif ( diff --git a/transformer_engine/pytorch/ops/basic/__init__.py b/transformer_engine/pytorch/ops/basic/__init__.py index e0a3f41019..45c938ede8 100644 --- a/transformer_engine/pytorch/ops/basic/__init__.py +++ b/transformer_engine/pytorch/ops/basic/__init__.py @@ -32,4 +32,4 @@ from .reduce_scatter import ReduceScatter from .reshape import Reshape from .rmsnorm import RMSNorm -from .swiglu import ClampedSwiGLU, ScaledSwiGLU, SwiGLU +from .swiglu import ClampedSwiGLU, ScaledClampedQGeGLU, ScaledSwiGLU, SwiGLU diff --git a/transformer_engine/pytorch/ops/basic/swiglu.py b/transformer_engine/pytorch/ops/basic/swiglu.py index b4427df41a..9c0bc86bc1 100644 --- a/transformer_engine/pytorch/ops/basic/swiglu.py +++ b/transformer_engine/pytorch/ops/basic/swiglu.py @@ -17,7 +17,7 @@ from ..op import BasicOperation, OperationContext from .._common import maybe_dequantize -__all__ = ["SwiGLU", "ClampedSwiGLU", "ScaledSwiGLU"] +__all__ = ["SwiGLU", "ClampedSwiGLU", "ScaledSwiGLU", "ScaledClampedQGeGLU"] class SwiGLU(BasicOperation): @@ -231,6 +231,34 @@ def __init__( self.cache_quantized_input: bool = cache_quantized_input self.glu_interleave_size: 
Optional[int] = glu_interleave_size + def _tex_clamped_swiglu_forward( + self, + swiglu_in: torch.Tensor, + next_op_input_quantizer: Optional[Quantizer], + ) -> torch.Tensor: + """Call :func:`tex.clamped_swiglu` with this op's ``limit`` / ``alpha``.""" + return tex.clamped_swiglu( + swiglu_in, + next_op_input_quantizer, + self.limit, + self.alpha, + ) + + def _tex_clamped_dswiglu( + self, + dy: torch.Tensor, + swiglu_in: torch.Tensor, + quantizer: Optional[Quantizer], + ) -> torch.Tensor: + """Call :func:`tex.clamped_dswiglu` with this op's ``limit`` / ``alpha``.""" + return tex.clamped_dswiglu( + dy, + swiglu_in, + quantizer, + self.limit, + self.alpha, + ) + def op_forward( self, ctx: OperationContext, @@ -252,7 +280,7 @@ def op_forward( x = maybe_dequantize(input_.contiguous(), dtype) # Remove interleaving if needed - swiglu_in = input_ + swiglu_in = x if self.glu_interleave_size is not None: shape = swiglu_in.size() swiglu_in = swiglu_in.reshape( @@ -265,12 +293,7 @@ def op_forward( swiglu_in = swiglu_in.view(shape) # Launch kernel - out = tex.clamped_swiglu( - swiglu_in, - next_op_input_quantizer, - limit=self.limit, - alpha=self.alpha, - ) + out = self._tex_clamped_swiglu_forward(swiglu_in, next_op_input_quantizer) # Quantize input to FP8 before caching if needed if self.cache_quantized_input: @@ -320,13 +343,7 @@ def op_backward( quantizer = None # Launch kernel - grad_swiglu_in = tex.clamped_dswiglu( - dy, - swiglu_in, - quantizer, - limit=self.limit, - alpha=self.alpha, - ) + grad_swiglu_in = self._tex_clamped_dswiglu(dy, swiglu_in, quantizer) # Apply interleaving if needed dx = grad_swiglu_in @@ -347,29 +364,25 @@ def op_backward( return dx, () -class ScaledSwiGLU(BasicOperation): - r"""SwiGLU with post-scaling. 
+class _ScaledGLU(BasicOperation): + """SwiGLU-family activation with per-row scales (fused grouped MLP middle op).""" - If the SwiGLU output has shape ``(d_1, ..., d_n)``, it is - multiplied with an extra input tensor of shape - ``(d_1, ..., d_{n-1})``. - - Parameters - ---------- - glu_interleave_size : int, optional - When set, the GLU activations will use an experimental block - interleaved format. See the corresponding option in the SwiGLU - operation for more details. - - """ - - # Operation expects scales num_extra_inputs: int = 1 - def __init__(self, glu_interleave_size: Optional[int] = None): + def __init__(self, glu_interleave_size: Optional[int] = None) -> None: super().__init__() self.glu_interleave_size: Optional[int] = glu_interleave_size + def _glu_forward(self, swiglu_in: torch.Tensor) -> torch.Tensor: + raise NotImplementedError + + def _glu_backward( + self, + grad_swiglu_out: torch.Tensor, + swiglu_in: torch.Tensor, + ) -> torch.Tensor: + raise NotImplementedError + def op_forward(self, *args, **kwargs) -> None: raise RuntimeError( f"{self.__class__.__name__} operation has " @@ -423,8 +436,7 @@ def fuser_forward( swiglu_in = swiglu_in.transpose(1, 2).contiguous() swiglu_in = swiglu_in.view(shape) - # Compute scaled SwiGLU - swiglu_out = tex.swiglu(swiglu_in, None) + swiglu_out = self._glu_forward(swiglu_in) out = swiglu_out * scales.unsqueeze(-1) # Save state for backward pass @@ -477,7 +489,7 @@ def fuser_backward( grad_input = None if ctx.input_requires_grad: grad_swiglu_out = grad_output * scales.unsqueeze(-1) - grad_swiglu_in = tex.dswiglu(grad_swiglu_out, swiglu_in, None) + grad_swiglu_in = self._glu_backward(grad_swiglu_out, swiglu_in) grad_input = grad_swiglu_in if self.glu_interleave_size is not None: shape = grad_input.size() @@ -490,13 +502,87 @@ def fuser_backward( grad_input = grad_input.transpose(1, 2).contiguous() grad_input = grad_input.view(shape) - # Compute scales grad by recomputing SwiGLU + # Compute scales grad by recomputing 
GLU grad_extra_input = None if ctx.extra_input_requires_grad: - swiglu_out = tex.swiglu(swiglu_in, None) + swiglu_out = self._glu_forward(swiglu_in) grad_extra_input = torch.linalg.vecdot(swiglu_out, grad_output) # Clear input tensor if possible clear_tensor_data(ctx.saved_tensors[0]) # input_ return grad_input, [()], [(grad_extra_input,)] + + +class ScaledSwiGLU(_ScaledGLU): + r"""SwiGLU with post-scaling (matches cuDNN grouped GEMM ``act_func="swiglu"``). + + If the GLU output has shape ``(d_1, ..., d_n)``, it is multiplied + with an extra input tensor of shape ``(d_1, ..., d_{n-1})``. + + Parameters + ---------- + glu_interleave_size : int, optional + When set, the GLU activations will use an experimental block + interleaved format. See the corresponding option in the SwiGLU + operation for more details. + + """ + + def _glu_forward(self, swiglu_in: torch.Tensor) -> torch.Tensor: + return tex.swiglu(swiglu_in, None) + + def _glu_backward( + self, + grad_swiglu_out: torch.Tensor, + swiglu_in: torch.Tensor, + ) -> torch.Tensor: + return tex.dswiglu(grad_swiglu_out, swiglu_in, None) + + +class ScaledClampedQGeGLU(_ScaledGLU): + r"""Clamped QGeGLU with post-scaling + (matches cuDNN grouped GEMM ``act_func="geglu"``). + + Same layout and scaling contract as :class:`ScaledSwiGLU`, but the GLU + uses :class:`ClampedSwiGLU` numerics (default ``limit`` / ``alpha`` match + cuDNN). + + Parameters + ---------- + glu_interleave_size : int, optional + When set, the GLU activations will use an experimental block + interleaved format. See :class:`ClampedSwiGLU`. + limit : float, default ``7.0`` + Clamp limit (see :class:`ClampedSwiGLU`). + alpha : float, default ``1.702`` + Sigmoid scale (see :class:`ClampedSwiGLU`). 
+ + """ + + def __init__( + self, + glu_interleave_size: Optional[int] = None, + *, + limit: float = 7.0, + alpha: float = 1.702, + ) -> None: + super().__init__(glu_interleave_size) + self._clamped: ClampedSwiGLU = ClampedSwiGLU( + limit=limit, + alpha=alpha, + ) + + def _glu_forward(self, swiglu_in: torch.Tensor) -> torch.Tensor: + return self._clamped._tex_clamped_swiglu_forward(swiglu_in, None) + + def _glu_backward( + self, + grad_swiglu_out: torch.Tensor, + swiglu_in: torch.Tensor, + ) -> torch.Tensor: + return self._clamped._tex_clamped_dswiglu( + grad_swiglu_out, + swiglu_in, + None, + ) diff --git a/transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py b/transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py index a821258ebf..6b452b0182 100644 --- a/transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py +++ b/transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py @@ -24,7 +24,7 @@ from ...tensor.mxfp8_tensor import MXFP8Quantizer from ...utils import clear_tensor_data, get_cached_ones_tensor, get_device_compute_capability from ...constants import MXFP8_BLOCK_SCALING_SIZE -from ..basic import GroupedLinear, ScaledSwiGLU +from ..basic import GroupedLinear, ScaledClampedQGeGLU, ScaledSwiGLU from ..fuser import register_backward_fusion from ..op import FusedOperation, FusibleOperation, OperationContext from .._common import ( @@ -181,7 +181,7 @@ def _compute_grad_params( class BackwardGroupedMLP_CuTeGEMMDSwiGLU_MXFP8(FusedOperation): - """Fused op for MXFP8 GroupedLinear + ScaledSwiGLU + GroupedLinear + """Fused op for MXFP8 GroupedLinear + ScaledSwiGLU or ScaledClampedQGeGLU + GroupedLinear Uses experimental CuTe DSL kernel from cuDNN front-end. 
@@ -229,7 +229,7 @@ def __init__( self, *, fc1: GroupedLinear, - swiglu: ScaledSwiGLU, + swiglu: ScaledSwiGLU | ScaledClampedQGeGLU, fc2: GroupedLinear, ) -> None: super().__init__((fc1, swiglu, fc2)) @@ -237,6 +237,11 @@ def __init__( self.grouped_gemm_dglu_kernel() # Try triggering import error raise RuntimeError(f"{self.__class__.__name__} is not supported on this system.") validate_grouped_mlp_dims(fc1, swiglu, fc2) + # The cuDNN dgeglu implementation corresponds to ScaledClampedQGeGLU. + # The act_func string should be fixed on the cuDNN FE side. + self._cudnn_dact_func: str = ( + "dgeglu" if isinstance(swiglu, ScaledClampedQGeGLU) else "dswiglu" + ) def fuser_backward( self, @@ -433,7 +438,7 @@ def fuser_backward( "sf_vec_size": MXFP8_BLOCK_SCALING_SIZE, "current_stream": current_stream, "discrete_col_sfd": True, - "act_func": "dswiglu", + "act_func": self._cudnn_dact_func, "use_dynamic_sched": True, } diff --git a/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py b/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py index c5ce2b148d..afabec8392 100644 --- a/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py +++ b/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py @@ -20,7 +20,7 @@ from ...tensor.grouped_tensor import GroupedTensor from ...tensor.mxfp8_tensor import MXFP8Quantizer from ...constants import MXFP8_BLOCK_SCALING_SIZE -from ..basic import GroupedLinear, ScaledSwiGLU +from ..basic import GroupedLinear, ScaledClampedQGeGLU, ScaledSwiGLU from ..fuser import register_forward_fusion from ..op import FusedOperation, FusibleOperation, OperationContext from .._common import ( @@ -46,7 +46,7 @@ def _pack_grouped_linear_bias_for_cudnn(linear_op: GroupedLinear) -> Optional[to class ForwardGroupedMLP_CuTeGEMMSwiGLU_MXFP8(FusedOperation): - """Fused op for MXFP8 GroupedLinear + ScaledSwiGLU + GroupedLinear + """Fused op for MXFP8 GroupedLinear + scaled GLU + GroupedLinear Uses experimental CuTe DSL kernel from cuDNN 
front-end. @@ -123,7 +123,7 @@ def __init__( self, *, fc1: GroupedLinear, - swiglu: ScaledSwiGLU, + swiglu: ScaledSwiGLU | ScaledClampedQGeGLU, fc2: GroupedLinear, ) -> None: super().__init__((fc1, swiglu, fc2)) @@ -131,6 +131,9 @@ def __init__( self.grouped_gemm_glu_kernel() # Try triggering import error raise RuntimeError(f"{self.__class__.__name__} is not supported on this system.") validate_grouped_mlp_dims(fc1, swiglu, fc2) + # The cuDNN geglu implementation corresponds to ScaledClampedQGeGLU. + # The act_func string should be fixed on the cuDNN FE side. + self._cudnn_act_func: str = "geglu" if isinstance(swiglu, ScaledClampedQGeGLU) else "swiglu" def fuser_forward( self, @@ -339,7 +342,7 @@ def fuser_forward( "sf_vec_size": MXFP8_BLOCK_SCALING_SIZE, "current_stream": current_stream, "discrete_col_sfd": True, - "act_func": "swiglu", + "act_func": self._cudnn_act_func, "use_dynamic_sched": True, }