From 32670fde5227a2b6005f77324cec1af6eecf682e Mon Sep 17 00:00:00 2001 From: Pedram Razavi Date: Sat, 4 Apr 2026 11:40:50 -0700 Subject: [PATCH] [PyTorch] Fix FlashAttention 2 head_dim > 192 on sm103 and other architectures Replace the exact-match compute capability allowlist with a >= sm80 range check, matching flash-attn's own gate: https://github.com/Dao-AILab/flash-attention/commit/bbb21d6533e9c8d6537e58843df288af86f57193 The allowlist ((8,0), (9,0), (10,0), (12,0)) missed sm103 (B300), sm89 (L40S), sm86 (A40), and others where FA2 supports head_dim up to 256. The sm103 case was validated on hardware with head_dim=256; the remaining architectures appear to be supported based on flash-attn's >= sm80 guarantee. Signed-off-by: Pedram Razavi --- .../pytorch/attention/dot_product_attention/utils.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index 170cb2cd34..d3eb70a2ad 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -631,16 +631,13 @@ def get_attention_backend( if use_flash_attention_2 and ( head_dim_qk > 256 or head_dim_qk % 8 != 0 - or ( - head_dim_qk > 192 - and device_compute_capability not in ((8, 0), (9, 0), (10, 0), (12, 0)) - ) + or (head_dim_qk > 192 and device_compute_capability < (8, 0)) ): if FlashAttentionUtils.is_installed: logger.debug( "Disabling FlashAttention 2 due to unsupported head_dim_qk and head_dim_v. " "Supported: head_dim_qk = head_dim_v, head_dim_qk %%8 = 0, " - "head_dim_qk <= 256 (>192 requires sm80/90/100+). " + "head_dim_qk <= 256 (>192 requires sm80+). " "Found: head_dim_qk = %s, head_dim_v = %s, on sm%s.", head_dim_qk, head_dim_v,