Conversation
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
Greptile Summary

This PR simplifies FA3 discovery by removing a post-install file-copy step and instead importing directly from the installed `flash_attn_interface` module.

Confidence Score: 5/5 — safe to merge. The simplification is correct, and the two remaining findings are P2 quality/robustness suggestions that do not block the primary use case: one is a defensive-coding suggestion (wrap the FA3 imports in a `try/except ImportError`), the other concerns pinning the FA3 clone to the listed version tag in CI. No files require special attention; optional hardening is suggested in backends.py and test.sh.
| Filename | Overview |
|---|---|
| transformer_engine/pytorch/attention/dot_product_attention/backends.py | Switches FA3 imports from a copied local file to direct `from flash_attn_interface import ...` statements; the `else` branch has no `ImportError` guard, so a partially installed FA3 package would crash the whole module load. |
| transformer_engine/pytorch/attention/dot_product_attention/utils.py | Minor update to installation instructions comment for FA3 — no logic changes, looks correct. |
| qa/L3_pytorch_FA_versions_test/test.sh | Removes file-copy step after FA3 install; FA3 clone always uses HEAD without pinning to the listed version tag, risking non-reproducible CI. |
Flowchart

```mermaid
%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[backends.py module load] --> B{get_pkg_version\nflash-attn-3}
    B -- PackageNotFoundError --> C[flash_attn_func_v3 = None\nflash_attn_varlen_func_v3 = None\nflash_attn_with_kvcache_v3 = None]
    B -- found --> D[from flash_attn_interface import ...]
    D -- success --> E[fa_utils.set_flash_attention_3_params\nv3_is_installed = True]
    D -- ImportError not caught --> F[Module load crash]
    C --> G[FA3 disabled gracefully]
    E --> H[FA3 available for use]
```
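The discovery flow above can be sketched in plain Python. `get_pkg_version` and the `fa_utils` hooks are TE internals, so this sketch approximates them with `importlib.metadata` and a module-level flag; the names mirror the flowchart nodes but the glue code is illustrative, not TE's actual implementation.

```python
from importlib.metadata import PackageNotFoundError, version

flash_attn_func_v3 = None
v3_is_installed = False

try:
    fa3_version = version("flash-attn-3")   # metadata lookup (node B)
except PackageNotFoundError:
    fa3_version = None                      # FA3 disabled gracefully (nodes C, G)
else:
    # Direct imports from the installed package (node D); any failure here
    # is the unguarded path the review warns about (node F).
    from flash_attn_interface import flash_attn_func as flash_attn_func_v3
    v3_is_installed = True                  # node E

print("FA3 installed:", v3_is_installed)
```

On a machine without the `flash-attn-3` package this falls through the `except` branch and leaves the `_v3` symbols as `None`, which is exactly the graceful-degradation path the PR relies on.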
Comments Outside Diff (1)
transformer_engine/pytorch/attention/dot_product_attention/backends.py, lines 133-151 (link)

Unguarded `ImportError` in `else` branch

If `get_pkg_version("flash-attn-3")` succeeds (metadata is found) but `flash_attn_interface` is somehow not importable (e.g., an unusual install layout or a broken `sys.path`), the bare `from flash_attn_interface import ...` calls in the `else` block will raise an unhandled `ImportError` that bubbles up and prevents the entire `backends` module from loading, crashing TE's import entirely. Wrapping the imports in a nested `try/except ImportError` would degrade gracefully to the "not installed" state instead:

```python
try:
    fa_utils.fa3_version = PkgVersion(get_pkg_version("flash-attn-3"))
except PackageNotFoundError:
    flash_attn_func_v3 = None
    flash_attn_varlen_func_v3 = None
    flash_attn_with_kvcache_v3 = None
else:
    try:
        from flash_attn_interface import flash_attn_func as flash_attn_func_v3
        from flash_attn_interface import (
            flash_attn_varlen_func as flash_attn_varlen_func_v3,
        )
        from flash_attn_interface import (
            flash_attn_with_kvcache as flash_attn_with_kvcache_v3,
        )
        from flash_attn_interface import _flash_attn_forward as _flash_attn_fwd_v3
        from flash_attn_interface import _flash_attn_backward as _flash_attn_bwd_v3

        fa_utils.set_flash_attention_3_params()
    except ImportError:
        flash_attn_func_v3 = None
        flash_attn_varlen_func_v3 = None
        flash_attn_with_kvcache_v3 = None
```
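The failure mode this comment describes — metadata lookup succeeding while the import fails — can be reproduced without FA3 at all. The sketch below fabricates a distribution that has package metadata but no importable module; the directory and package names are illustrative, chosen only to mirror the FA3 case.

```python
# Demonstrate that importlib.metadata can report a version for a package
# whose module is not actually importable (a "partially installed" state).
import os
import sys
import tempfile
from importlib import import_module
from importlib.metadata import PackageNotFoundError, version

# Create a fake site dir containing only a .dist-info directory: metadata
# exists, but no Python module was ever installed alongside it.
site_dir = tempfile.mkdtemp()
dist_info = os.path.join(site_dir, "flash_attn_3-3.0.0b1.dist-info")
os.makedirs(dist_info)
with open(os.path.join(dist_info, "METADATA"), "w") as f:
    f.write("Metadata-Version: 2.1\nName: flash-attn-3\nVersion: 3.0.0b1\n")
sys.path.insert(0, site_dir)

try:
    fa3_version = version("flash-attn-3")   # succeeds: metadata is present
except PackageNotFoundError:
    fa3_version = None

try:
    import_module("flash_attn_interface")   # fails: no module was installed
    importable = True
except ImportError:
    importable = False

print("metadata says:", fa3_version, "| importable:", importable)
```

Here the version check alone would send `backends.py` down the `else` branch, where the bare imports would raise — which is why the nested `try/except ImportError` above is worth having.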
Reviews (1): Last reviewed commit: "Simplify FA3 discovery"
```shell
git clone https://github.com/Dao-AILab/flash-attention.git
cd flash-attention/hopper && python setup.py install
python_path=`python -c "import site; print(site.getsitepackages()[0])"`
mkdir -p $python_path/flash_attn_3
cp flash_attn_interface.py $python_path/flash_attn_3/
cd ../../
fi
```
FA3 install always uses HEAD, not the specified version

For FA3 (`fa_version = "3.0.0b1"`), the script clones the default branch HEAD without checking out the tag or commit corresponding to that version. If the upstream flash-attention main branch advances its API between runs, tests may pass or fail inconsistently and won't actually validate the 3.0.0b1 release. Adding a `git checkout` after the clone would pin the test to the intended version:
```diff
 git clone https://github.com/Dao-AILab/flash-attention.git
-cd flash-attention/hopper && python setup.py install
-python_path=`python -c "import site; print(site.getsitepackages()[0])"`
-mkdir -p $python_path/flash_attn_3
-cp flash_attn_interface.py $python_path/flash_attn_3/
+cd flash-attention && git checkout v${fa_version} && cd hopper && python setup.py install
 cd ../../
 fi
```
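The effect of the suggested `git checkout` can be demonstrated against a throwaway local repository, without touching the network. The repository, tag name, and commit messages below are fabricated for the demo; the real script would clone Dao-AILab/flash-attention and check out `v${fa_version}`.

```shell
#!/bin/sh
# Sketch: a plain clone lands on HEAD; checking out a tag pins the build
# to the tagged commit even after upstream moves on.
set -e
work=$(mktemp -d)
cd "$work"

# Build an "upstream" repo with a tagged release and a later commit.
git init -q upstream
cd upstream
git -c user.email=ci@example.com -c user.name=ci \
    commit -q --allow-empty -m "release"
git tag v3.0.0b1
git -c user.email=ci@example.com -c user.name=ci \
    commit -q --allow-empty -m "later API change"
cd ..

# Clone (lands on HEAD = "later API change"), then pin to the tag.
git clone -q upstream build
cd build
git checkout -q v3.0.0b1
git log -1 --format=%s   # prints: release
```

Without the `git checkout`, `git log -1` would show "later API change" — the non-reproducibility the review flags.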
Description
Use FA3 install as is, without copying files.
Fixes # (issue)
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: