diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 601149916b..85068f6536 100755 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -31,6 +31,22 @@ repos: args: [--line-length=100, --preview, --enable-unstable-feature=string_processing] types: [python] + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.11.4 + hooks: + - id: ruff + name: Lint python code + types: [python] + files: ^transformer_engine/ + + - repo: https://github.com/cpplint/cpplint + rev: '1.6.0' + hooks: + - id: cpplint + types_or: [c, c++, cuda] + files: ^transformer_engine/(common|jax|pytorch)/ + exclude: ^transformer_engine/build_tools/build/ + - repo: https://github.com/pre-commit/mirrors-clang-format rev: v18.1.6 hooks: diff --git a/pylintrc b/pylintrc deleted file mode 100644 index 50f85fad9d..0000000000 --- a/pylintrc +++ /dev/null @@ -1,38 +0,0 @@ -[MASTER] -extension-pkg-whitelist=flash_attn_2_cuda, - torch, - transformer_engine_torch, - transformer_engine_jax - -disable=too-many-locals, - too-few-public-methods, - too-many-public-methods, - too-many-positional-arguments, - invalid-name, - too-many-arguments, - abstract-method, - arguments-differ, - too-many-instance-attributes, - unsubscriptable-object, - import-outside-toplevel, - too-many-statements, - import-error, - too-many-lines, - use-maxsplit-arg, - protected-access, - pointless-string-statement, - cyclic-import, - duplicate-code, - no-member, - attribute-defined-outside-init, - global-statement, - too-many-branches, - global-variable-not-assigned, - redefined-argument-from-local, - line-too-long, - too-many-return-statements, - too-many-nested-blocks - -[TYPECHECK] -ignored-modules=torch -ignored-classes=torch diff --git a/pyproject.toml b/pyproject.toml index 4a8fded172..7fb4e50afd 100755 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,3 +7,40 @@ requires = ["setuptools>=61.0", "cmake>=3.21", "wheel", "pybind11[global]", "nin # Use legacy backend to import local packages in setup.py build-backend = "setuptools.build_meta:__legacy__" + +[tool.ruff] +line-length = 100 + +[tool.ruff.format] +preview = true +docstring-code-format = true + +[tool.ruff.lint] +select = ["E", "F", "W", "PL"] +ignore = [ + "E402", # module-level-import-not-at-top (pylint import-outside-toplevel was disabled) + "E501", # line-too-long + "E731", # lambda-assignment + "E741", # ambiguous-variable-name (pylint invalid-name was disabled) + "PLR0904", # too-many-public-methods + "PLR0911", # too-many-return-statements + "PLR0912", # too-many-branches + "PLR0913", # too-many-arguments + "PLR0914", # too-many-locals + "PLR0915", # too-many-statements + "PLR0917", # too-many-positional-arguments + "PLR1702", # too-many-nested-blocks + "PLR1704", # redefined-argument-from-local + "PLR2004", # magic-value-comparison + "PLR5501", # collapsible-else-if + "PLW0602", # global-variable-not-assigned + "PLW0603", # global-statement + "PLW2901", # redefined-loop-name + "PLC0415", # import-outside-toplevel +] + +[tool.ruff.lint.per-file-ignores] +"__init__.py" = ["F401", "F403"] +"transformer_engine/pytorch/fp8.py" = ["F401"] +"transformer_engine/pytorch/export.py" = ["F401"] +"transformer_engine/pytorch/attention/dot_product_attention/backends.py" = ["F401"] diff --git a/qa/L0_jax_lint/test.sh b/qa/L0_jax_lint/test.sh index 3f804d3ef9..940d4293d5 100755 --- a/qa/L0_jax_lint/test.sh +++ b/qa/L0_jax_lint/test.sh @@ -1,12 +1,13 @@ # Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. +# NOTE: This test is duplicated from pre-commit, and could be deleted. set -e : "${TE_PATH:=/opt/transformerengine}" -pip3 install cpplint==1.6.0 pylint==3.3.1 +pip3 install cpplint==1.6.0 ruff==0.11.4 if [ -z "${PYTHON_ONLY}" ] then cd $TE_PATH @@ -20,5 +21,5 @@ if [ -z "${CPP_ONLY}" ] then cd $TE_PATH echo "Checking Python files" - python3 -m pylint --recursive=y transformer_engine/common transformer_engine/jax + python3 -m ruff check transformer_engine/common transformer_engine/jax fi diff --git a/qa/L0_pytorch_lint/test.sh b/qa/L0_pytorch_lint/test.sh index f08dd8a03d..434fc319f7 100644 --- a/qa/L0_pytorch_lint/test.sh +++ b/qa/L0_pytorch_lint/test.sh @@ -1,12 +1,13 @@ # Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. +# NOTE: This script is duplicated from pre-commit checks, and could be deleted. set -e : "${TE_PATH:=/opt/transformerengine}" -pip3 install cpplint==1.6.0 pylint==3.3.1 +pip3 install cpplint==1.6.0 ruff==0.11.4 if [ -z "${PYTHON_ONLY}" ] then cd $TE_PATH @@ -20,5 +21,5 @@ if [ -z "${CPP_ONLY}" ] then cd $TE_PATH echo "Checking Python files" - python3 -m pylint --recursive=y transformer_engine/common transformer_engine/pytorch transformer_engine/debug + python3 -m ruff check transformer_engine/common transformer_engine/pytorch transformer_engine/debug fi diff --git a/transformer_engine/debug/__init__.py b/transformer_engine/debug/__init__.py index 446e192d86..f6c7241706 100644 --- a/transformer_engine/debug/__init__.py +++ b/transformer_engine/debug/__init__.py @@ -7,5 +7,5 @@ try: from . import pytorch from .pytorch.debug_state import set_weight_tensor_tp_group_reduce -except ImportError as e: +except ImportError: pass diff --git a/transformer_engine/debug/features/_test_dummy_feature.py b/transformer_engine/debug/features/_test_dummy_feature.py index f74cd95e9d..1ae4d81b62 100644 --- a/transformer_engine/debug/features/_test_dummy_feature.py +++ b/transformer_engine/debug/features/_test_dummy_feature.py @@ -43,7 +43,7 @@ def inspect_tensor_enabled(self, config, *_args, **_kwargs): """ # Access counter via full module path to ensure we're modifying the same module-level # variable regardless of import context (debug framework vs test import) - import transformer_engine.debug.features._test_dummy_feature as dummy_feature # pylint: disable=import-self + import transformer_engine.debug.features._test_dummy_feature as dummy_feature # noqa: PLW0406 # pylint: disable=import-self dummy_feature._inspect_tensor_enabled_call_count += 1 @@ -56,6 +56,6 @@ def inspect_tensor_enabled(self, config, *_args, **_kwargs): def inspect_tensor(self, _config, *_args, **_kwargs): """This method does nothing but always tracks invocations for testing.""" # Access counter via full module path to ensure shared state across import contexts - import transformer_engine.debug.features._test_dummy_feature as dummy_feature # pylint: disable=import-self + import transformer_engine.debug.features._test_dummy_feature as dummy_feature # noqa: PLW0406 # pylint: disable=import-self dummy_feature._inspect_tensor_call_count += 1 diff --git a/transformer_engine/debug/features/api.py b/transformer_engine/debug/features/api.py index ee9a187b3c..5d7c149fd0 100644 --- a/transformer_engine/debug/features/api.py +++ b/transformer_engine/debug/features/api.py @@ -466,10 +466,7 @@ def output_assertions_hook(self, api_name, ret, **kwargs): assert ret is None if api_name == "modify_tensor": assert type(ret) in get_all_tensor_types() - if ( - type(ret) == torch.Tensor # pylint: disable=unidiomatic-typecheck - and "dtype" in kwargs - ): + if type(ret) is torch.Tensor and "dtype" in kwargs: if kwargs["dtype"] is not None: assert ret.dtype == kwargs["dtype"] diff --git a/transformer_engine/debug/pytorch/debug_quantization.py b/transformer_engine/debug/pytorch/debug_quantization.py index ed5fdd4660..c1e593b01e 100644 --- a/transformer_engine/debug/pytorch/debug_quantization.py +++ b/transformer_engine/debug/pytorch/debug_quantization.py @@ -418,8 +418,7 @@ def any_feature_enabled(self) -> bool: self.inspect_tensor_enabled or self.inspect_tensor_postquantize_enabled_rowwise or self.inspect_tensor_postquantize_enabled_columnwise - or self.rowwise_tensor_plan == API_CALL_MODIFY - or self.columnwise_tensor_plan == API_CALL_MODIFY + or API_CALL_MODIFY in (self.rowwise_tensor_plan, self.columnwise_tensor_plan) ): return True if self.parent_quantizer is not None: diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index c081e451a7..6ad4b4e041 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -816,7 +816,6 @@ def batcher( sequence_dim, is_outer, ): - del transpose_batch_sequence, sequence_dim, is_outer if GemmPrimitive.outer_primitive is None: raise RuntimeError("GemmPrimitive.outer_primitive has not been registered") lhs_bdims, _, rhs_bdims, *_ = batch_dims diff --git a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py index 64cccaac6e..fd518f9b09 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py @@ -3355,10 +3355,7 @@ def forward( assert attn_bias_type == "no_bias", f"{attn_bias_type} bias type is not supported!" assert q.shape[-1] % 8 == 0, "Hidden size per attention head should be multiple of 8!" assert ( - window_size == (-1, 0) - or window_size == (-1, -1) - or use_fused_attention - or fa_utils.v2_3_plus + window_size in ((-1, 0), (-1, -1)) or use_fused_attention or fa_utils.v2_3_plus ), "Sliding window attention only can work with FusedAttention or FlashAttention >= 2.3!" flash_attn_fwd = None @@ -4061,9 +4058,7 @@ def attn_forward_func_with_cp( cu_seqlens_q_padded is not None and cu_seqlens_kv_padded is not None ), "cu_seqlens_padded can not be None for context parallelism and qkv_format = 'thd'!" - sliding_window_attn = ( - window_size is not None and window_size != (-1, 0) and window_size != (-1, -1) - ) + sliding_window_attn = window_size is not None and window_size not in ((-1, 0), (-1, -1)) assert not sliding_window_attn or cp_comm_type in [ "a2a", "all_gather", diff --git a/transformer_engine/pytorch/custom_recipes/quantization_current_scaling.py b/transformer_engine/pytorch/custom_recipes/quantization_current_scaling.py index 8580cf4a33..cea02ff8a6 100644 --- a/transformer_engine/pytorch/custom_recipes/quantization_current_scaling.py +++ b/transformer_engine/pytorch/custom_recipes/quantization_current_scaling.py @@ -399,9 +399,7 @@ def qgemm( return y # cublas fp8 gemm does not support fp32 bias - use_bias_in_gemm = ( - bias is not None and out_dtype != torch.float32 and bias.dtype != torch.float32 - ) + use_bias_in_gemm = bias is not None and torch.float32 not in (out_dtype, bias.dtype) # Run quantized gemm: y = qw * qx scaled_mm_res = torch._scaled_mm( diff --git a/transformer_engine/pytorch/distributed.py b/transformer_engine/pytorch/distributed.py index b80e58fe20..026a445b2e 100644 --- a/transformer_engine/pytorch/distributed.py +++ b/transformer_engine/pytorch/distributed.py @@ -17,7 +17,6 @@ from torch.cuda import _lazy_call, _lazy_init from torch.utils.checkpoint import detach_variable, noop_context_fn from torch.distributed.fsdp import FullyShardedDataParallel as FSDP -from torch.distributed.fsdp._common_utils import _get_module_fsdp_state from torch.distributed.fsdp._traversal_utils import _get_fsdp_states_with_modules try: diff --git a/transformer_engine/pytorch/ops/basic/grouped_linear.py b/transformer_engine/pytorch/ops/basic/grouped_linear.py index f26a337a4d..bc9fa53a92 100644 --- a/transformer_engine/pytorch/ops/basic/grouped_linear.py +++ b/transformer_engine/pytorch/ops/basic/grouped_linear.py @@ -499,7 +499,9 @@ def pre_first_fuser_forward(self) -> None: f"Weight {group_idx} has requires_grad={weight.requires_grad}, " f"but expected requires_grad={weight_requires_grad}." ) - if type(weight.data) != weight_tensor_type: # pylint: disable=unidiomatic-typecheck + if ( + type(weight.data) is not weight_tensor_type + ): # pylint: disable=unidiomatic-typecheck raise RuntimeError( f"Weight {group_idx} has invalid tensor type " f"(expected {weight_tensor_type.__name__}, " diff --git a/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py b/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py index 8df929f799..46c15ae557 100644 --- a/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py +++ b/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py @@ -77,7 +77,7 @@ def fuser_forward( if basic_op_kwargs[idx]: raise ValueError("Bias operation forward does not expect keyword arguments") if self._op_idxs["activation"] is None: - activation_op = None # pylint: disable=unused-variable + pass # No activation op needed else: raise NotImplementedError("Activations are not yet supported")