diff --git a/tests/pytorch/test_fusible_ops.py b/tests/pytorch/test_fusible_ops.py index 795cbf3452..8e901d7c5a 100644 --- a/tests/pytorch/test_fusible_ops.py +++ b/tests/pytorch/test_fusible_ops.py @@ -76,6 +76,13 @@ _quantization_list.append("nvfp4") +@pytest.fixture(autouse=True, scope="class") +def _reset_rng_states_per_class(): + """Restore torch, CUDA, and Python ``random`` RNG state once per test class.""" + reset_rng_states() + yield + + def maybe_skip_quantization( quantization: Optional[str], *, @@ -363,10 +370,6 @@ def test_extra_tensors(self, size: int = 16) -> None: class TestFuser: """Tests for operation fusion infrastructure""" - @staticmethod - def setup_class(cls) -> None: - reset_rng_states() - @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) def test_fp8_scale_update( self, @@ -579,10 +582,6 @@ def test_pyt_autocast( class TestBasicOps: """Tests for individual operations""" - @staticmethod - def setup_class(cls) -> None: - reset_rng_states() - @pytest.mark.parametrize("dtype", _dtypes) @pytest.mark.parametrize("device", ("cuda", "cpu")) @pytest.mark.parametrize("quantization", _quantization_list) @@ -2326,10 +2325,6 @@ def test_interleaved_scaled_clamped_qgeglu(self): class TestFusedOps: """Tests for fused operations""" - @staticmethod - def setup_class(cls) -> None: - reset_rng_states() - @pytest.mark.parametrize("weight_shape", ((32, 64), (3, 5))) @pytest.mark.parametrize("in_shape", ((-1,), (1, 7, -1), (8, 2, 10, -1))) @pytest.mark.parametrize("dtype", _dtypes) @@ -3034,10 +3029,6 @@ def test_backward_linear_scale( class TestCheckpointing: """Tests for checkpointing""" - @staticmethod - def setup_class(cls) -> None: - reset_rng_states() - @pytest.mark.parametrize("quantization", _quantization_list) @pytest.mark.parametrize("quantized_weight", (False, True)) def test_linear( @@ -3150,10 +3141,6 @@ def test_linear( class TestSequentialModules: """Test for larger Sequentials with modules commonly used together""" - @staticmethod 
- def setup_class(cls) -> None: - reset_rng_states() - @pytest.mark.parametrize("requires_grad", (False, True)) @pytest.mark.parametrize("bias", (False, True)) @pytest.mark.parametrize("quantized_compute", (False, True)) @@ -3337,13 +3324,14 @@ def to_cpu(tensor: Optional[torch.Tensor]) -> Optional[torch.Tensor]: @pytest.mark.parametrize("accumulate_into_main_grad", (False, True)) @pytest.mark.parametrize("glu_interleave_size", (None, 32)) @pytest.mark.parametrize("delay_wgrad_compute", (False, True)) + @pytest.mark.parametrize("hidden_size", (128, 256)) @pytest.mark.parametrize("activation", ("scaled_swiglu", "scaled_clamped_qgeglu")) def test_grouped_mlp( self, *, group_size: int = 4, bias: bool, - hidden_size: int = 256, + hidden_size: int, dtype: torch.dtype, quantization: Optional[str], single_grouped_weight: bool, diff --git a/tests/pytorch/utils.py b/tests/pytorch/utils.py index 196ae8c165..fd9a6416ec 100644 --- a/tests/pytorch/utils.py +++ b/tests/pytorch/utils.py @@ -6,6 +6,7 @@ import logging import os +import random import subprocess from contextlib import contextmanager from typing import Optional, Sequence, Tuple, Dict, Any, List @@ -173,8 +174,8 @@ def skip_unsupported_backward_override( pytest.skip(f"{layer_type} does not support NVTE_BACKWARD_OVERRIDE={backward_override}.") -# Cached RNG state -_rng_states: Optional[Tuple[torch.Tensor, torch.Tensor]] = None +# Cached RNG state (torch CPU, torch CUDA, Python ``random``) +_rng_states: Optional[Tuple[torch.Tensor, torch.Tensor, Any]] = None def reset_rng_states() -> None: @@ -183,11 +184,17 @@ def reset_rng_states() -> None: if _rng_states is None: torch.manual_seed(1234) torch.cuda.manual_seed(1234) - _rng_states = (torch.get_rng_state(), torch.cuda.get_rng_state()) + random.seed(1234) + _rng_states = ( + torch.get_rng_state(), + torch.cuda.get_rng_state(), + random.getstate(), + ) else: - cpu_rng_state, cuda_rng_state = _rng_states + cpu_rng_state, cuda_rng_state, random_state = _rng_states 
torch.set_rng_state(cpu_rng_state) torch.cuda.set_rng_state(cuda_rng_state) + random.setstate(random_state) def compare_and_assert(a, b, name_a, name_b, atol, rtol, rmse_tol, is_fp8): diff --git a/transformer_engine/pytorch/ops/_common.py b/transformer_engine/pytorch/ops/_common.py index ae8b48a90d..15dc17e812 100644 --- a/transformer_engine/pytorch/ops/_common.py +++ b/transformer_engine/pytorch/ops/_common.py @@ -88,12 +88,12 @@ def get_fp8_meta_from_fp8_tensor(tensor: Float8Tensor) -> tuple[FP8TensorMeta, i def validate_grouped_mlp_dims(fc1, glu_op, fc2) -> None: """Validate FC1 / scaled GLU / FC2 dimensions for fused grouped MLP.""" - if fc1.in_features % 256 != 0 or fc1.out_features % 256 != 0: + if fc1.in_features % 64 != 0 or fc1.out_features % 64 != 0: raise ValueError( f"Unsupported dims for FC1 (num_groups={fc1.num_groups}, " f"in_features={fc1.in_features}, out_features={fc1.out_features})." ) - if fc2.in_features % 256 != 0 or fc2.out_features % 256 != 0: + if fc2.in_features % 64 != 0 or fc2.out_features % 64 != 0: raise ValueError( f"Unsupported dims for FC2 (num_groups={fc2.num_groups}, " f"in_features={fc2.in_features}, out_features={fc2.out_features})." 
@@ -176,10 +176,10 @@ def fuse_grouped_mlp_ops( elif window[0].num_groups != window[2].num_groups: matches_pattern = False elif ( - window[0].in_features % 256 != 0 - or window[0].out_features % 256 != 0 - or window[2].in_features % 256 != 0 - or window[2].out_features % 256 != 0 + window[0].in_features % 64 != 0 + or window[0].out_features % 64 != 0 + or window[2].in_features % 64 != 0 + or window[2].out_features % 64 != 0 ): matches_pattern = False elif window[1].glu_interleave_size != 32: diff --git a/transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py b/transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py index 6b452b0182..6c8b251cc8 100644 --- a/transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py +++ b/transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py @@ -406,8 +406,8 @@ def fuser_backward( fc2_dy_scales = fc2_dy_scales.view(dtype=torch.float8_e8m0fnu) fc2_dy_scales = fc2_dy_scales.view( 1, - out_shape[0] // 128, - out_shape[1] // 128, + (out_shape[0] + 127) // 128, + (out_shape[1] + 127) // 128, MXFP8_BLOCK_SCALING_SIZE, 4, 4, @@ -456,8 +456,8 @@ def fuser_backward( fc2_w_scales = fc2_weight_for_gemm.columnwise_scale_inv.view(dtype=torch.float8_e8m0fnu) fc2_w_scales = fc2_w_scales.view( num_groups, - fc2_weight_shape[1] // 128, - fc2_weight_shape[0] // 128, + (fc2_weight_shape[1] + 127) // 128, + (fc2_weight_shape[0] + 127) // 128, MXFP8_BLOCK_SCALING_SIZE, 4, 4, @@ -587,8 +587,8 @@ def fuser_backward( ) fc1_w_scales = fc1_w_scales.view( num_groups, - fc1_weight_shape[1] // 128, - fc1_weight_shape[0] // 128, + (fc1_weight_shape[1] + 127) // 128, + (fc1_weight_shape[0] + 127) // 128, MXFP8_BLOCK_SCALING_SIZE, 4, 4, diff --git a/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py b/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py index afabec8392..69096f2e39 100644 --- a/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py +++ b/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py @@ -154,6 
+154,7 @@ def fuser_forward( fc2_weight_shape = (fc2_op.out_features, fc2_op.in_features) input_ = input_.reshape(-1, fc1_weight_shape[1]) in_shape = list(input_.size()) + assert in_shape[0] % 128 == 0, "Unsupported input shape for fused grouped MLP." num_groups = fc1_op.num_groups fc1_weight_param = fc1_op.weight if fc1_op.single_grouped_weight else fc1_op.weight0 @@ -312,8 +313,8 @@ def fuser_forward( fc1_x_scales = fc1_x_scales.view(dtype=torch.float8_e8m0fnu) fc1_x_scales = fc1_x_scales.view( 1, - in_shape[0] // 128, - in_shape[1] // 128, + (in_shape[0] + 127) // 128, + (in_shape[1] + 127) // 128, MXFP8_BLOCK_SCALING_SIZE, 4, 4, @@ -361,8 +362,8 @@ def fuser_forward( fc1_w_scales = fc1_weight_for_gemm.scale_inv.view(dtype=torch.float8_e8m0fnu) fc1_w_scales = fc1_w_scales.view( num_groups, - fc1_weight_shape[0] // 128, - fc1_weight_shape[1] // 128, + (fc1_weight_shape[0] + 127) // 128, + (fc1_weight_shape[1] + 127) // 128, MXFP8_BLOCK_SCALING_SIZE, 4, 4, @@ -458,8 +459,8 @@ def fuser_forward( fc2_w_scales = fc2_weight_for_gemm.scale_inv.view(dtype=torch.float8_e8m0fnu) fc2_w_scales = fc2_w_scales.view( num_groups, - fc2_weight_shape[0] // 128, - fc2_weight_shape[1] // 128, + (fc2_weight_shape[0] + 127) // 128, + (fc2_weight_shape[1] + 127) // 128, MXFP8_BLOCK_SCALING_SIZE, 4, 4,