diff --git a/qa/L3_pytorch_FA_versions_test/test.sh b/qa/L3_pytorch_FA_versions_test/test.sh
index 6e239bfb72..bbfc4db5ba 100644
--- a/qa/L3_pytorch_FA_versions_test/test.sh
+++ b/qa/L3_pytorch_FA_versions_test/test.sh
@@ -34,9 +34,6 @@ do
     else
         git clone https://github.com/Dao-AILab/flash-attention.git
         cd flash-attention/hopper && python setup.py install
-        python_path=`python -c "import site; print(site.getsitepackages()[0])"`
-        mkdir -p $python_path/flash_attn_3
-        cp flash_attn_interface.py $python_path/flash_attn_3/
         cd ../../
     fi
 
diff --git a/transformer_engine/pytorch/attention/dot_product_attention/backends.py b/transformer_engine/pytorch/attention/dot_product_attention/backends.py
index 1e7bdaac84..b5ed15f8e0 100644
--- a/transformer_engine/pytorch/attention/dot_product_attention/backends.py
+++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py
@@ -6,7 +6,6 @@
 from contextlib import nullcontext
 from importlib.metadata import version as get_pkg_version
 from importlib.metadata import PackageNotFoundError
-import importlib.util
 import os
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 import warnings
@@ -139,35 +138,15 @@ flash_attn_with_kvcache_v3 = None
 #     pass  # only print warning if use_flash_attention_3 = True in get_attention_backend
 else:
-    if importlib.util.find_spec("flash_attn_3.flash_attn_interface") is not None:
-        from flash_attn_3.flash_attn_interface import flash_attn_func as flash_attn_func_v3
-        from flash_attn_3.flash_attn_interface import (
-            flash_attn_varlen_func as flash_attn_varlen_func_v3,
-        )
-        from flash_attn_3.flash_attn_interface import (
-            flash_attn_with_kvcache as flash_attn_with_kvcache_v3,
-        )
-        from flash_attn_3.flash_attn_interface import _flash_attn_forward as _flash_attn_fwd_v3
-        from flash_attn_3.flash_attn_interface import _flash_attn_backward as _flash_attn_bwd_v3
-    elif importlib.util.find_spec("flash_attn_interface") is not None:
-        warnings.warn(
-            "flash_attn_interface found outside flash_attn_3 package. "
-            "Importing directly from flash_attn_interface."
-        )
-        from flash_attn_interface import flash_attn_func as flash_attn_func_v3
-        from flash_attn_interface import (
-            flash_attn_varlen_func as flash_attn_varlen_func_v3,
-        )
-        from flash_attn_interface import (
-            flash_attn_with_kvcache as flash_attn_with_kvcache_v3,
-        )
-        from flash_attn_interface import _flash_attn_forward as _flash_attn_fwd_v3
-        from flash_attn_interface import _flash_attn_backward as _flash_attn_bwd_v3
-    else:
-        raise ModuleNotFoundError(
-            "flash-attn-3 package is installed but flash_attn_interface module "
-            "could not be found in flash_attn_3/ or site-packages/."
-        )
+    from flash_attn_interface import flash_attn_func as flash_attn_func_v3
+    from flash_attn_interface import (
+        flash_attn_varlen_func as flash_attn_varlen_func_v3,
+    )
+    from flash_attn_interface import (
+        flash_attn_with_kvcache as flash_attn_with_kvcache_v3,
+    )
+    from flash_attn_interface import _flash_attn_forward as _flash_attn_fwd_v3
+    from flash_attn_interface import _flash_attn_backward as _flash_attn_bwd_v3
 
     fa_utils.set_flash_attention_3_params()
diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py
index 170cb2cd34..13d1347a1e 100644
--- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py
+++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py
@@ -135,10 +135,7 @@ class FlashAttentionUtils:
     # Please follow these instructions to install FA3
     v3_installation_steps = """\
 (1) git clone https://github.com/Dao-AILab/flash-attention.git
-(2) cd flash-attention/hopper && python setup.py install
-(3) python_path=`python -c "import site; print(site.getsitepackages()[0])"`
-(4) mkdir -p $python_path/flash_attn_3
-(5) cp flash_attn_interface.py $python_path/flash_attn_3/flash_attn_interface.py"""
+(2) cd flash-attention/hopper && python setup.py install"""
     v3_warning_printed = False
 
     @staticmethod
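
Verification sketch (not part of the patch): after `cd flash-attention/hopper && python setup.py install`, the patched backends.py imports the FA3 kernels directly from the `flash_attn_interface` module, so a quick sanity check is simply to confirm that module resolves. The probe below is illustrative only; the filename `probe_fa3.py` and the printed messages are placeholders, not anything introduced by this change.

    # probe_fa3.py -- illustrative check that the FA3 interface module resolves
    # after the simplified install steps (git clone + hopper/setup.py install).
    try:
        import flash_attn_interface  # module name used by the new imports in backends.py
    except ImportError:
        print("flash_attn_interface not importable; FA3 path will stay disabled")
    else:
        print("flash_attn_interface found at:", flash_attn_interface.__file__)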