Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion 3rdparty/cudnn-frontend
Submodule cudnn-frontend updated 89 files
+1 −1 CMakeLists.txt
+2 −0 README.md
+359 −199 include/cudnn_frontend/graph_interface.h
+14 −0 include/cudnn_frontend/graph_properties.h
+7 −7 include/cudnn_frontend/node/diagonal_band_mask.h
+23 −2 include/cudnn_frontend/node/scaled_dot_product_flash_attention.h
+38 −5 include/cudnn_frontend/node/sdpa_fp8_bwd.h
+7 −7 include/cudnn_frontend/node/softmax.h
+202 −192 include/cudnn_frontend/plans.h
+1 −1 include/cudnn_frontend_version.h
+1 −0 python/cudnn/README.md
+25 −1 python/cudnn/__init__.py
+137 −61 python/cudnn/discrete_grouped_gemm/discrete_grouped_gemm_dswiglu/api.py
+207 −173 ...cudnn/discrete_grouped_gemm/discrete_grouped_gemm_dswiglu/discrete_B_blockscaled_grouped_gemm_dglu_dbias.py
+146 −61 python/cudnn/discrete_grouped_gemm/discrete_grouped_gemm_swiglu/api.py
+241 −128 ...on/cudnn/discrete_grouped_gemm/discrete_grouped_gemm_swiglu/discrete_B_blockscaled_grouped_gemm_glu_bias.py
+37 −8 python/cudnn/discrete_grouped_gemm/discrete_kernel_utils.py
+3 −0 python/cudnn/experimental/__init__.py
+3 −0 python/cudnn/experimental/ops/__init__.py
+1,079 −0 python/cudnn/experimental/ops/sdpa.py
+189 −412 python/cudnn/grouped_gemm/grouped_gemm_dglu/api.py
+0 −4,427 python/cudnn/grouped_gemm/grouped_gemm_dglu/continugous_blockscaled_grouped_gemm_dglu_quant_dbias_fusion.py
+159 −97 python/cudnn/grouped_gemm/grouped_gemm_dglu/moe_blockscaled_grouped_gemm_dglu_dbias.py
+4 −2 python/cudnn/grouped_gemm/grouped_gemm_dswiglu/grouped_gemm_dswiglu_quant.py
+202 −403 python/cudnn/grouped_gemm/grouped_gemm_glu/api.py
+0 −3,713 python/cudnn/grouped_gemm/grouped_gemm_glu/continugous_blockscaled_grouped_gemm_glu_quant_bias_fusion.py
+218 −90 python/cudnn/grouped_gemm/grouped_gemm_glu/moe_blockscaled_grouped_gemm_glu_bias.py
+349 −60 python/cudnn/grouped_gemm/grouped_gemm_quant/api.py
+10 −5 python/cudnn/grouped_gemm/grouped_gemm_quant/grouped_gemm_quant.py
+6 −4 python/cudnn/grouped_gemm/grouped_gemm_swiglu/grouped_gemm_swiglu_quant.py
+36 −7 python/cudnn/grouped_gemm/moe_kernel_helpers.py
+12 −0 python/cudnn/sdpa/__init__.py
+581 −0 python/cudnn/sdpa/api.py
+438 −0 python/cudnn/sdpa/fmha_backward_sm100_2kernel.py
+3,016 −0 python/cudnn/sdpa/fmha_dkdv_d256_sm100.py
+1,968 −0 python/cudnn/sdpa/fmha_dq_d256_sm100.py
+1,143 −0 python/cudnn/sdpa/fmha_utils.py
+784 −0 python/cudnn/sdpa/utils.py
+24 −0 python/cudnn/wrapper.py
+47 −0 python/pygraph/pygraph.cpp
+23 −2 python/pygraph/pygraph.h
+10 −4 python/pygraph/sdpa.cpp
+2 −4 samples/cpp/misc/serialization.cpp
+2 −2 samples/cpp/sdpa/fp16_fwd_with_max_and_sum_exp.cpp
+2 −1 samples/legacy_samples/fp8_flash_mha_sample.cpp
+2 −2 samples/legacy_samples/fp8_flash_mha_sample.h
+1 −1 samples/legacy_samples/test_list.cpp
+4 −4 test/cpp/tensor.cpp
+9 −1 test/python/conftest.py
+152 −0 test/python/fe_api/test_discrete_grouped_gemm_dswiglu.py
+201 −7 test/python/fe_api/test_discrete_grouped_gemm_dswiglu_utils.py
+148 −0 test/python/fe_api/test_discrete_grouped_gemm_swiglu.py
+15 −1 test/python/fe_api/test_discrete_grouped_gemm_swiglu_utils.py
+3 −0 test/python/fe_api/test_fe_api_utils.py
+384 −0 test/python/fe_api/test_grouped_gemm_dglu.py
+19 −8 test/python/fe_api/test_grouped_gemm_dswiglu_utils.py
+389 −0 test/python/fe_api/test_grouped_gemm_glu.py
+391 −0 test/python/fe_api/test_grouped_gemm_quant.py
+45 −22 test/python/fe_api/test_grouped_gemm_quant_utils.py
+28 −12 test/python/fe_api/test_grouped_gemm_swiglu_utils.py
+157 −0 test/python/fe_api/test_sdpa_bwd.py
+352 −0 test/python/fe_api/test_sdpa_bwd_utils.py
+1 −0 test/python/sdpa/fp16.py
+6 −2 test/python/sdpa/fp8.py
+11 −9 test/python/sdpa/mxfp8.py
+4 −1 test/python/sdpa/mxfp8_ref.py
+1 −0 test/python/sdpa/random_config.py
+579 −0 test/python/test_cudnn_sdpa_op.py
+32 −6 test/python/test_mhas_v2.py
+107 −0 test/python/test_sdpa_fp8_serialization.py
+7 −1 tools/cudnn_repro/README.md
+13 −34 tools/cudnn_repro/cudnn_repro/__main__.py
+44 −0 tools/cudnn_repro/cudnn_repro/repro_command.py
+55 −0 tools/cudnn_repro/cudnn_repro/routing.py
+2 −7 tools/cudnn_repro/cudnn_repro/stage1_annotate.py
+67 −15 tools/cudnn_repro/cudnn_repro/stage1_annotate_sdpa_bwd.py
+168 −0 tools/cudnn_repro/cudnn_repro/stage1_annotate_sdpa_fp8_bwd.py
+168 −0 tools/cudnn_repro/cudnn_repro/stage1_annotate_sdpa_fp8_fwd.py
+2 −7 tools/cudnn_repro/cudnn_repro/stage2_build_repro.py
+4 −32 tools/cudnn_repro/cudnn_repro/stage2_build_repro_sdpa_bwd.py
+26 −0 tools/cudnn_repro/cudnn_repro/stage2_build_repro_sdpa_fp8_bwd.py
+26 −0 tools/cudnn_repro/cudnn_repro/stage2_build_repro_sdpa_fp8_fwd.py
+4 −31 tools/cudnn_repro/cudnn_repro/stage2_build_repro_sdpa_fwd.py
+61 −0 tools/cudnn_repro/cudnn_repro/utils.py
+172 −0 tools/cudnn_repro/tests/test_cudnn_repro_bwd.py
+90 −0 tools/cudnn_repro/tests/test_cudnn_repro_closed_loop.py
+229 −0 tools/cudnn_repro/tests/test_cudnn_repro_fp8.py
+25 −0 tools/cudnn_repro/tests/test_cudnn_repro_fp8_closed_loop.py
+94 −0 tools/cudnn_repro/tests/test_cudnn_repro_schema.py
5 changes: 4 additions & 1 deletion transformer_engine/common/fused_attn/fused_attn.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -339,7 +339,10 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK))) ||
// 9.11: d_qk = 192, d_v = 128 + Blackwell + bprop + non-paged
(head_dim_qk == 192 && head_dim_v == 128 && is_training && sm_arch_ >= 100 &&
cudnn_runtime_version >= 91100)) &&
cudnn_runtime_version >= 91100) ||
// 9.20: any head_dim + Blackwell + fprop/bprop + non_paged + any sq
(sm_arch_ >= 100 && cudnn_runtime_version >= 92000 &&
layout_group != NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD)) &&
Comment on lines +343 to +345
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Verify sq=1 + causal/padding_causal fprop support in cuDNN 9.20

The 9.20 condition allows any max_seqlen_q (including sq = 1) with any mask type on non-paged Blackwell layouts. The preceding 9.10.2 fprop path explicitly excluded sq = 1 + causal and sq = 1 + padding_causal on non-paged layouts:

(max_seqlen_q == 1 && attn_mask_type != NVTE_Mask_Type::NVTE_CAUSAL_MASK &&
 attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK)

With the 9.20 path (any sq, no mask-type restriction at the head-dim level), sq=1 + causal + non-paged + fprop on Blackwell/cuDNN≥9.20 will now pass this gate — where it was previously blocked. If cuDNN 9.20 lifts this restriction for SM≥100, this is correct. If not, passing this combination to the backend would produce a runtime error. Please confirm whether cuDNN 9.20 actually supports this combination on Blackwell.

// 9.11+ bug: 128 < d_qk <= 256, 128 < d_v <= 256 + Hopper + bprop + MLA
// Conditional to temporarily use blanket cudnn_runtime_version >= 9.11 until fixed
(!((cudnn_runtime_version >= 91100) && is_training && sm_arch_ == 90 &&
Expand Down