From e915db479adc911fbe873aaa2c7b8a318e3a6409 Mon Sep 17 00:00:00 2001 From: jianzhu Date: Thu, 2 Apr 2026 21:03:39 +0800 Subject: [PATCH] fix(CP, FA): fix the FA version check used to unpack the outputs of the Flash Attn forward pass, so that Flash Attn 3 never takes the `not v2_7_0_plus` branch --- .../dot_product_attention/context_parallel.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py index 64cccaac6e..09630cc35d 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py @@ -961,11 +961,10 @@ def cp_p2p_fwd_flash_attn( **fa_forward_kwargs, ) rng_states = None - if not fa_utils.v2_7_0_plus: + if not use_flash_attn_3 and not fa_utils.v2_7_0_plus: out_per_step = fa_outputs[4] softmax_lse_per_step = fa_outputs[5] - if not use_flash_attn_3: - rng_states = fa_outputs[7] + rng_states = fa_outputs[7] else: out_per_step = fa_outputs[0] softmax_lse_per_step = fa_outputs[1] @@ -3006,11 +3005,10 @@ def forward( causal=causal, **fa_forward_kwargs, ) - if not fa_utils.v2_7_0_plus: + if not use_flash_attn_3 and not fa_utils.v2_7_0_plus: out_per_step[i] = fa_outputs[4] softmax_lse_per_step[i] = fa_outputs[5] - if not use_flash_attn_3: - rng_states[i] = fa_outputs[7] + rng_states[i] = fa_outputs[7] else: out_per_step[i] = fa_outputs[0] softmax_lse_per_step[i] = fa_outputs[1] @@ -3544,9 +3542,9 @@ def forward( causal=causal, **fa_forward_kwargs, ) - if not fa_utils.v2_7_0_plus: + if not use_flash_attn_3 and not fa_utils.v2_7_0_plus: out_, softmax_lse = fa_outputs[4], fa_outputs[5] - rng_state = fa_outputs[7] if not use_flash_attn_3 else None + rng_state = fa_outputs[7] else: out_, softmax_lse = fa_outputs[0], fa_outputs[1] rng_state = fa_outputs[3] if not use_flash_attn_3 else None