Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions qa/L3_pytorch_FA_versions_test/test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,6 @@ do
else
git clone https://github.com/Dao-AILab/flash-attention.git
cd flash-attention/hopper && python setup.py install
python_path=`python -c "import site; print(site.getsitepackages()[0])"`
mkdir -p $python_path/flash_attn_3
cp flash_attn_interface.py $python_path/flash_attn_3/
cd ../../
fi
Comment on lines 35 to 38
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 FA3 install always uses HEAD, not the specified version

For FA3 (fa_version = "3.0.0b1"), the script clones the default branch HEAD without checking out the tag or commit corresponding to that version. If the upstream flash-attention main branch advances its API between runs, tests may pass or fail inconsistently and won't actually validate the 3.0.0b1 release. Adding a git checkout after the clone would pin the test to the intended version:

Suggested change
git clone https://github.com/Dao-AILab/flash-attention.git
cd flash-attention/hopper && python setup.py install
python_path=`python -c "import site; print(site.getsitepackages()[0])"`
mkdir -p $python_path/flash_attn_3
cp flash_attn_interface.py $python_path/flash_attn_3/
cd ../../
fi
git clone https://github.com/Dao-AILab/flash-attention.git
cd flash-attention && git checkout v${fa_version} && cd hopper && python setup.py install
cd ../../
fi


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from contextlib import nullcontext
from importlib.metadata import version as get_pkg_version
from importlib.metadata import PackageNotFoundError
import importlib.util
import os
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import warnings
Expand Down Expand Up @@ -139,35 +138,15 @@
flash_attn_with_kvcache_v3 = None
# pass # only print warning if use_flash_attention_3 = True in get_attention_backend
else:
if importlib.util.find_spec("flash_attn_3.flash_attn_interface") is not None:
from flash_attn_3.flash_attn_interface import flash_attn_func as flash_attn_func_v3
from flash_attn_3.flash_attn_interface import (
flash_attn_varlen_func as flash_attn_varlen_func_v3,
)
from flash_attn_3.flash_attn_interface import (
flash_attn_with_kvcache as flash_attn_with_kvcache_v3,
)
from flash_attn_3.flash_attn_interface import _flash_attn_forward as _flash_attn_fwd_v3
from flash_attn_3.flash_attn_interface import _flash_attn_backward as _flash_attn_bwd_v3
elif importlib.util.find_spec("flash_attn_interface") is not None:
warnings.warn(
"flash_attn_interface found outside flash_attn_3 package. "
"Importing directly from flash_attn_interface."
)
from flash_attn_interface import flash_attn_func as flash_attn_func_v3
from flash_attn_interface import (
flash_attn_varlen_func as flash_attn_varlen_func_v3,
)
from flash_attn_interface import (
flash_attn_with_kvcache as flash_attn_with_kvcache_v3,
)
from flash_attn_interface import _flash_attn_forward as _flash_attn_fwd_v3
from flash_attn_interface import _flash_attn_backward as _flash_attn_bwd_v3
else:
raise ModuleNotFoundError(
"flash-attn-3 package is installed but flash_attn_interface module "
"could not be found in flash_attn_3/ or site-packages/."
)
from flash_attn_interface import flash_attn_func as flash_attn_func_v3
from flash_attn_interface import (
flash_attn_varlen_func as flash_attn_varlen_func_v3,
)
from flash_attn_interface import (
flash_attn_with_kvcache as flash_attn_with_kvcache_v3,
)
from flash_attn_interface import _flash_attn_forward as _flash_attn_fwd_v3
from flash_attn_interface import _flash_attn_backward as _flash_attn_bwd_v3

fa_utils.set_flash_attention_3_params()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -135,10 +135,7 @@ class FlashAttentionUtils:
# Please follow these instructions to install FA3
v3_installation_steps = """\
(1) git clone https://github.com/Dao-AILab/flash-attention.git
(2) cd flash-attention/hopper && python setup.py install
(3) python_path=`python -c "import site; print(site.getsitepackages()[0])"`
(4) mkdir -p $python_path/flash_attn_3
(5) cp flash_attn_interface.py $python_path/flash_attn_3/flash_attn_interface.py"""
(2) cd flash-attention/hopper && python setup.py install"""
v3_warning_printed = False

@staticmethod
Expand Down
Loading