fix(CP, FA): the conditional logic in the FA version contains a vulnerability when processing the output of Flash Attn forward pass #2825

Open
zhujian19891203 wants to merge 1 commit into NVIDIA:main from 021ai:CP_FA_fix

@zhujian19891203
Contributor

Description

The FA version check that processes the output of the Flash Attention forward pass contains a logic bug, which was probably introduced accidentally.

For example, when only FA3 is installed and FA2 is not installed at all, the output of the forward pass is handled incorrectly.

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Only the context_parallel.py file is changed

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

…he conditional logic in the FA version contains a vulnerability, fix it
@greptile-apps
Contributor

greptile-apps bot commented Apr 2, 2026

Greptile Summary

This PR fixes a conditional logic bug in three symmetric locations within context_parallel.py where the Flash Attention forward-pass output indices were selected based solely on fa_utils.v2_7_0_plus, ignoring use_flash_attn_3. When only FA3 is installed (FA2 absent), the original code fell into the old-FA2 branch and incorrectly accessed fa_outputs[4]/fa_outputs[5]/fa_outputs[7], whereas FA3's output is indexed at [0]/[1]. The fix adds not use_flash_attn_3 to the guard so FA3 always routes through the correct else branch regardless of the FA2 version flags.
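The index selection described above can be sketched as a small standalone function. This is only an illustration under stated assumptions, not the code in context_parallel.py: `select_fa_outputs` is a hypothetical helper, the flags are passed in as plain booleans, and the tuples stand in for real forward-pass results. Only the index positions and the guard condition are taken from the summary.

```python
from typing import Any, Optional, Sequence, Tuple


def select_fa_outputs(
    fa_outputs: Sequence[Any],
    use_flash_attn_3: bool,
    v2_7_0_plus: bool,
) -> Tuple[Any, Any, Optional[Any]]:
    """Pick (out, softmax_lse, rng_state) from a forward-pass result.

    Hypothetical helper: the name and signature are assumptions made
    for this sketch, not part of Transformer Engine.
    """
    if not use_flash_attn_3 and not v2_7_0_plus:
        # Older FA2 layout: out, lse and rng_state at indices 4, 5 and 7.
        return fa_outputs[4], fa_outputs[5], fa_outputs[7]
    # FA3 and newer FA2 both expose out and lse at indices 0 and 1.
    out, lse = fa_outputs[0], fa_outputs[1]
    # FA3 yields no rng_state on this path; newer FA2 keeps it at index 3.
    rng_state = None if use_flash_attn_3 else fa_outputs[3]
    return out, lse, rng_state


# Placeholder values standing in for an FA3 result (assumed shape).
fa3_result = ("out", "lse")
print(select_fa_outputs(fa3_result, use_flash_attn_3=True, v2_7_0_plus=False))
# → ('out', 'lse', None)
```

With the FA3 flag checked first, the value of the FA2 version flag no longer matters for FA3, which is the case the PR description reports as broken.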

Confidence Score: 5/5

Safe to merge — the fix is minimal, correct, and consistently applied across all three affected call sites.

All three changes follow the same correct pattern: the use_flash_attn_3 guard is added to prevent FA3 from being misrouted into the old FA2 index scheme. The else branch already handled FA3 correctly, and the rng_state/rng_states assignments are consistent with the pre-existing None path for FA3. No new logic paths are introduced; only the dead-code/wrong-index branch for FA3 is eliminated. No P0/P1 findings remain.
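The claim that no other path changes can be checked by enumerating both flags. The sketch below assumes the original guard was `not fa_utils.v2_7_0_plus` and the fixed guard is `not use_flash_attn_3 and not fa_utils.v2_7_0_plus`, as the summary describes. `old_guard` and `new_guard` are hypothetical names used only for this comparison.

```python
from itertools import product


def old_guard(use_flash_attn_3: bool, v2_7_0_plus: bool) -> bool:
    # Guard as described for the original code: FA2 version flag only.
    return not v2_7_0_plus


def new_guard(use_flash_attn_3: bool, v2_7_0_plus: bool) -> bool:
    # Guard as described for the fix: FA3 never takes the old-FA2 branch.
    return not use_flash_attn_3 and not v2_7_0_plus


# Collect every (use_flash_attn_3, v2_7_0_plus) pair where the guards disagree.
changed = [
    (fa3, v27)
    for fa3, v27 in product([False, True], repeat=2)
    if old_guard(fa3, v27) != new_guard(fa3, v27)
]
print(changed)
# → [(True, False)]
```

The only combination that behaves differently is FA3 in use with the FA2 version flag unset, which matches the FA3-only installation from the PR description.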

No files require special attention.

Important Files Changed

| Filename | Overview |
| --- | --- |
| transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py | Three symmetric fixes to the FA forward-pass output index selection: adds a `not use_flash_attn_3` guard so FA3-only installations (no FA2) correctly use output indices [0, 1] instead of the old-format FA2 indices [4, 5, 7]. |

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[flash_attn_fwd returns fa_outputs] --> B{use_flash_attn_3?}
    B -- Yes --> C["else branch<br/>out = fa_outputs 0<br/>lse = fa_outputs 1<br/>rng = None"]
    B -- No --> D{fa_utils.v2_7_0_plus?}
    D -- "No<br/>Old FA2 format" --> E["if branch<br/>out = fa_outputs 4<br/>lse = fa_outputs 5<br/>rng = fa_outputs 7"]
    D -- "Yes<br/>New FA2 format" --> F["else branch<br/>out = fa_outputs 0<br/>lse = fa_outputs 1<br/>rng = fa_outputs 3"]

Reviews (1): Last reviewed commit: "fix(CP, FA): when processing the output ..."

@zhujian19891203 changed the title from "Fix bug: the conditional logic in the FA version contains a vulnerability when processing the output of Flash Attn forward pass" to "Fix(CP, FA): the conditional logic in the FA version contains a vulnerability when processing the output of Flash Attn forward pass" on Apr 3, 2026
@zhujian19891203 changed the title from "Fix(CP, FA): the conditional logic in the FA version contains a vulnerability when processing the output of Flash Attn forward pass" to "fix(CP, FA): the conditional logic in the FA version contains a vulnerability when processing the output of Flash Attn forward pass" on Apr 3, 2026
@zhujian19891203
Contributor Author

zhujian19891203 commented Apr 9, 2026

@cyanguwa Can you take a look at this PR, please?

@ptrendx added the community-contribution label (PRs from external contributor outside the core maintainers, representing community-driven work) on Apr 9, 2026