-
Notifications
You must be signed in to change notification settings - Fork 688
[PyTorch] Support scaled + clamped SwiGLU in te.ops and enable fused MXFP8 grouped MLP
#2855
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 2 commits
8227ae3
ce405a9
0fe9522
53f5ca9
9c411e8
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -2234,6 +2234,91 @@ def test_interleaved_scaled_swiglu(self): | |
| scales_requires_grad=True, | ||
| ) | ||
|
|
||
    @pytest.mark.parametrize("in_shape", ((71, 192), (5, 7, 128)))
    @pytest.mark.parametrize("input_requires_grad", (False, True))
    @pytest.mark.parametrize("scales_requires_grad", (False, True))
    def test_scaled_clamped_swiglu(
        self,
        *,
        in_shape: Iterable[int],
        glu_interleave_size: Optional[int] = None,
        dtype: torch.dtype = torch.float32,
        device: torch.device = "cuda",
        input_requires_grad: bool,
        scales_requires_grad: bool,
        limit: float = 7.0,
        alpha: float = 1.702,
    ) -> None:
        """Clamped SwiGLU with post-scale.

        Checks ``te_ops.ScaledClampedSwiGLU`` forward (and backward, when
        either input requires grad) against a plain-PyTorch reference:
        the gate half is clamped from above at ``limit``, the linear half
        is clamped to ``[-limit, limit]``, the gate uses a sigmoid scaled
        by ``alpha``, and the result is multiplied by a per-row ``scales``
        tensor.

        Parameters
        ----------
        in_shape:
            Input tensor shape; last dim is split in half (gate / linear),
            so it must be even (and divisible by ``2 * glu_interleave_size``
            when interleaving is enabled).
        glu_interleave_size:
            When not ``None``, input is interpreted in block-interleaved
            GLU format and de-interleaved before the reference computation.
        input_requires_grad / scales_requires_grad:
            Which inputs participate in the backward check.
        limit:
            Clamp bound applied to both halves (one-sided for the gate).
        alpha:
            Sigmoid scaling factor inside the gate activation.
        """

        # Tensor dims: output has half the channels of the input
        out_shape = list(in_shape)
        out_shape[-1] //= 2

        # Random data (paired reference/test tensors from the shared helper)
        x_ref, x_test = make_reference_and_test_tensors(
            in_shape,
            test_dtype=dtype,
            test_device=device,
            requires_grad=input_requires_grad,
        )
        # Post-activation scales: one scalar per row (all dims but the last)
        scales_ref, scales_test = make_reference_and_test_tensors(
            in_shape[:-1],
            test_dtype=dtype,
            test_device=device,
            requires_grad=scales_requires_grad,
        )
        dy_ref, dy_test = make_reference_and_test_tensors(
            out_shape,
            test_dtype=dtype,
            test_device=device,
            requires_grad=False,
        )

        # Plain PyTorch reference (matches :class:`ClampedSwiGLU` numerics)
        x = x_ref
        if glu_interleave_size is not None:
            # De-interleave block format: (..., blocks, 2, interleave) ->
            # swap the gate/linear axis ahead of the block axis so that a
            # plain chunk(2, dim=-1) below separates the two halves.
            x = x.reshape(
                -1,
                in_shape[-1] // (2 * glu_interleave_size),
                2,
                glu_interleave_size,
            )
            x = x.transpose(1, 2)
            x = x.reshape(in_shape)
        x_glu, x_linear = x.chunk(2, dim=-1)
        # Gate half clamped from above only; linear half clamped on both sides
        x_glu = x_glu.clamp(min=None, max=limit)
        x_linear = x_linear.clamp(min=-limit, max=limit)
        # SiLU-style gate with alpha-scaled sigmoid
        out_glu = x_glu * torch.sigmoid(alpha * x_glu)
        # NOTE: clamped SwiGLU adds 1 to the (clamped) linear branch
        y = out_glu * (x_linear + 1)
        y_ref = scales_ref.unsqueeze(-1) * y
        if input_requires_grad or scales_requires_grad:
            y_ref.backward(dy_ref)

        # Implementation under test
        op = te_ops.ScaledClampedSwiGLU(
            glu_interleave_size=glu_interleave_size,
            limit=limit,
            alpha=alpha,
        )
        y_test = op(x_test, scales_test)
        if input_requires_grad or scales_requires_grad:
            y_test.backward(dy_test)

        # Check results (grads are compared only where they exist;
        # assert_close_grads is assumed to tolerate absent grads — see helper)
        tols = dtype_tols(dtype)
        y_test = y_test.to(dtype=torch.float64, device="cpu")
        assert_close(y_test, y_ref, **tols)
        assert_close_grads(x_test, x_ref, **tols)
        assert_close_grads(scales_test, scales_ref, **tols)
|
|
||
| def test_interleaved_scaled_clamped_swiglu(self): | ||
| """Clamped SwiGLU with post-scale and block interleaved input format""" | ||
| self.test_scaled_clamped_swiglu( | ||
| in_shape=(32, 192), | ||
| glu_interleave_size=32, | ||
| input_requires_grad=True, | ||
| scales_requires_grad=True, | ||
| ) | ||
|
|
||
|
|
||
| class TestFusedOps: | ||
| """Tests for fused operations""" | ||
|
|
@@ -3249,6 +3334,7 @@ def to_cpu(tensor: Optional[torch.Tensor]) -> Optional[torch.Tensor]: | |
| @pytest.mark.parametrize("accumulate_into_main_grad", (False, True)) | ||
| @pytest.mark.parametrize("glu_interleave_size", (None, 32)) | ||
| @pytest.mark.parametrize("delay_wgrad_compute", (False, True)) | ||
| @pytest.mark.parametrize("activation", ("swiglu", "geglu")) | ||
| def test_grouped_mlp( | ||
| self, | ||
| *, | ||
|
|
@@ -3264,8 +3350,9 @@ def test_grouped_mlp( | |
| split_alignment: int = 256, | ||
| glu_interleave_size: Optional[int], | ||
| delay_wgrad_compute: bool, | ||
| activation: str, | ||
| ) -> None: | ||
| """GroupedLinear + ScaledSwiGLU + GroupedLinear""" | ||
| """GroupedLinear + ScaledSwiGLU / ScaledClampedSwiGLU + GroupedLinear""" | ||
|
|
||
| # Split sizes | ||
| split_sizes = [split_alignment * (i) for i in range(group_size)] | ||
|
|
@@ -3288,6 +3375,9 @@ def test_grouped_mlp( | |
| if quantization == "mxfp8" and bias: | ||
| # Will be supported in future CUDNN release. | ||
| pytest.skip("Bias/dbias not yet supported in MXFP8 fused grouped MLP") | ||
| if quantization == "nvfp4" and activation == "geglu" and bias: | ||
| # TODO: ksivaman: Need to debug numerics for this case. | ||
| pytest.skip("Bias/dbias not yet supported in NVFP4 fused grouped MLP with GeGLU") | ||
|
||
|
|
||
| # Random data | ||
| x_ref, x_test = make_reference_and_test_tensors( | ||
|
|
@@ -3376,7 +3466,14 @@ def test_grouped_mlp( | |
| x = x.transpose(1, 2) | ||
| x = x.reshape(-1, 2 * hidden_size) | ||
| x1, x2 = x.chunk(2, dim=-1) | ||
| x = torch.nn.functional.silu(x1) * x2 | ||
| if activation == "swiglu": | ||
| x = torch.nn.functional.silu(x1) * x2 | ||
| else: | ||
| lim = torch.tensor(7.0, device=x1.device, dtype=x1.dtype) | ||
| geglu_alpha = 1.702 | ||
| x1c = torch.minimum(x1, lim) | ||
| x2c = torch.clamp(x2, -lim, lim) | ||
| x = (x2c + 1) * (x1c * torch.sigmoid(geglu_alpha * x1c)) | ||
| x = x * probs[group_idx].unsqueeze(-1) | ||
| x = torch.nn.functional.linear(x, fc2_ws_ref[group_idx], bias=fc2_bs_ref[group_idx]) | ||
| ys.append(x) | ||
|
|
@@ -3385,6 +3482,11 @@ def test_grouped_mlp( | |
|
|
||
| # Construct operations | ||
| recipe = make_recipe(quantization) | ||
| scaled_act = ( | ||
| te_ops.ScaledSwiGLU(glu_interleave_size=glu_interleave_size) | ||
| if activation == "swiglu" | ||
| else te_ops.ScaledClampedSwiGLU(glu_interleave_size=glu_interleave_size) | ||
| ) | ||
| with te.quantized_model_init(enabled=with_quantization, recipe=recipe): | ||
| fc1 = te_ops.GroupedLinear( | ||
| group_size, | ||
|
|
@@ -3412,7 +3514,7 @@ def test_grouped_mlp( | |
| ) | ||
| module = te_ops.Sequential( | ||
| fc1, | ||
| te_ops.ScaledSwiGLU(glu_interleave_size=glu_interleave_size), | ||
| scaled_act, | ||
| fc2, | ||
| ) | ||
|
|
||
|
|
@@ -3572,13 +3674,15 @@ def test_grouped_mlp( | |
| @pytest.mark.parametrize("dtype", _dtypes) | ||
| @pytest.mark.parametrize("single_grouped_weight", (False, True)) | ||
| @pytest.mark.parametrize("accumulate_into_main_grad", (False, True)) | ||
| @pytest.mark.parametrize("activation", ("swiglu", "geglu")) | ||
| @pytest.mark.skipif(not mxfp8_available, reason=reason_for_no_mxfp8) | ||
| def test_grouped_mlp_cuda_graph_safe_mxfp8( | ||
| self, | ||
| *, | ||
| dtype: torch.dtype, | ||
| single_grouped_weight: bool, | ||
| accumulate_into_main_grad: bool, | ||
| activation: str, | ||
| device: torch.device = "cuda", | ||
| group_size: int = 4, | ||
| hidden_size: int = 256, | ||
|
|
@@ -3619,9 +3723,14 @@ def test_grouped_mlp_cuda_graph_safe_mxfp8( | |
| single_grouped_weight=single_grouped_weight, | ||
| accumulate_into_main_grad=accumulate_into_main_grad, | ||
| ) | ||
| scaled_act = ( | ||
| te_ops.ScaledSwiGLU(glu_interleave_size=glu_interleave_size) | ||
| if activation == "swiglu" | ||
| else te_ops.ScaledClampedSwiGLU(glu_interleave_size=glu_interleave_size) | ||
| ) | ||
| module = te_ops.Sequential( | ||
| fc1, | ||
| te_ops.ScaledSwiGLU(glu_interleave_size=glu_interleave_size), | ||
| scaled_act, | ||
| fc2, | ||
| ) | ||
|
|
||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.