@@ -631,16 +631,13 @@ def get_attention_backend(
if use_flash_attention_2 and (
head_dim_qk > 256
or head_dim_qk % 8 != 0
- or (
- head_dim_qk > 192
- and device_compute_capability not in ((8, 0), (9, 0), (10, 0), (12, 0))
- )
+ or (head_dim_qk > 192 and device_compute_capability < (8, 0))

P2 Dead code: condition is always False at this point

Lines 428–431 unconditionally set use_flash_attention_2 = False whenever device_compute_capability < (8, 0). By the time execution reaches line 634, use_flash_attention_2 can only be True if device_compute_capability >= (8, 0), so the sub-expression device_compute_capability < (8, 0) is never true and the entire third or branch is unreachable. The bug-fix intent is correct (no longer blocking head_dim > 192 on sm86/sm89/sm103), but the residual condition could be confusing to future readers who might believe it provides a meaningful guard.

Consider removing the dead branch entirely:

Suggested change
- or (head_dim_qk > 192 and device_compute_capability < (8, 0))
+ or head_dim_qk % 8 != 0
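
The reachability argument above can be checked with a small standalone sketch. The guard at lines 428–431 is not shown in this diff, so `fa2_enabled_after_arch_guard` is a hypothetical stand-in modeled only on the description in this comment, and the helper names and the list of compute capabilities are illustrative assumptions rather than code from the repository.

```python
from itertools import product


def fa2_enabled_after_arch_guard(use_fa2: bool, cc: tuple) -> bool:
    """Hypothetical model of the earlier guard: FA2 is turned off below sm80."""
    if cc < (8, 0):
        return False
    return use_fa2


def old_arch_branch(head_dim_qk: int, cc: tuple) -> bool:
    """Branch removed by this PR: allow-list of exact compute capabilities."""
    return head_dim_qk > 192 and cc not in ((8, 0), (9, 0), (10, 0), (12, 0))


def rejects_with_branch(head_dim_qk: int, cc: tuple) -> bool:
    """Head-dim check as written in this PR, including the `< (8, 0)` branch."""
    return (
        head_dim_qk > 256
        or head_dim_qk % 8 != 0
        or (head_dim_qk > 192 and cc < (8, 0))
    )


def rejects_without_branch(head_dim_qk: int, cc: tuple) -> bool:
    """Same check with the third branch dropped."""
    return head_dim_qk > 256 or head_dim_qk % 8 != 0


# Illustrative set of compute capabilities (assumed, not exhaustive).
capabilities = [(7, 0), (7, 5), (8, 0), (8, 6), (8, 9), (9, 0), (10, 0), (10, 3), (12, 0)]
head_dims = range(1, 321)

# Collect every reachable input where the two variants disagree.
mismatches = [
    (hd, cc)
    for hd, cc in product(head_dims, capabilities)
    if fa2_enabled_after_arch_guard(True, cc)
    and rejects_with_branch(hd, cc) != rejects_without_branch(hd, cc)
]
print("reachable mismatches:", mismatches)

# The old allow-list is what blocked head_dim_qk > 192 on sm86/sm89/sm103.
print("old branch blocked sm86 at 200:", old_arch_branch(200, (8, 6)))
```

Under these assumptions the two variants only disagree for inputs the earlier guard has already excluded (for example `head_dim_qk = 200` on sm75), which is why dropping the branch should not change behavior.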

):
if FlashAttentionUtils.is_installed:
logger.debug(
"Disabling FlashAttention 2 due to unsupported head_dim_qk and head_dim_v. "
"Supported: head_dim_qk = head_dim_v, head_dim_qk %%8 = 0, "
- "head_dim_qk <= 256 (>192 requires sm80/90/100+). "
+ "head_dim_qk <= 256 (>192 requires sm80+). "
"Found: head_dim_qk = %s, head_dim_v = %s, on sm%s.",
head_dim_qk,
head_dim_v,