11# Copyright (c) OpenMMLab. All rights reserved.
22
# Kernel translation units, one list per file-name prefix (sm70 / sm75 / sm80 /
# sm90). Keeping the groups in separate variables lets each group be attached
# to a target, or left out of it, as a unit.

# Files prefixed sm70_.
set(GEMM2_KERNELS_SM70
    "kernel/sm70_884_4.cu"
    "kernel/sm70_884_8.cu"
    "kernel/sm70_884_16.cu"
)

# Files prefixed sm75_.
set(GEMM2_KERNELS_SM75
    "kernel/sm75_16816_4.cu"
    "kernel/sm75_16816_8.cu"
    "kernel/sm75_16816_16.cu"
)

# Files prefixed sm80_.
set(GEMM2_KERNELS_SM80
    "kernel/sm80_16816_4.cu"
    "kernel/sm80_16816_8.cu"
    "kernel/sm80_16816_16.cu"
)

# Files prefixed sm90_. Unlike the other groups this one has a fourth entry.
set(GEMM2_KERNELS_SM90
    "kernel/sm90_16816_4.cu"
    "kernel/sm90_16816_8.cu"
    "kernel/sm90_16816_16.cu"
    "kernel/sm90_64n32_8.cu"
)
24+
# Decide whether the SM90 kernel group is built, and for which architectures.
#
# Results:
#   GEMM2_ARCH_90_ENABLED  TRUE when the SM90 kernels should be compiled.
#   _sm90_archs            The architecture entries to compile them for: either
#                          the entries of CMAKE_CUDA_ARCHITECTURES that start
#                          with "90", or the fallback value "90" (see below).
set(GEMM2_ARCH_90_ENABLED FALSE)

# Expand the list without any padding inside the quotes: extra whitespace
# would become part of the last list entry and later end up in a
# CUDA_ARCHITECTURES value.
set(_sm90_archs "${CMAKE_CUDA_ARCHITECTURES}")
list(FILTER _sm90_archs INCLUDE REGEX "^90")

if(_sm90_archs)
  set(GEMM2_ARCH_90_ENABLED TRUE)
else()
  # When building for SM100 without explicit SM90, still compile SM90 CUTLASS
  # kernels so the fat binary can run MoE models on H100 (CUTLASS fused path).
  # Only entries starting with "100" trigger this fallback.
  set(_sm100_archs "${CMAKE_CUDA_ARCHITECTURES}")
  list(FILTER _sm100_archs INCLUDE REGEX "^100")
  if(_sm100_archs)
    set(GEMM2_ARCH_90_ENABLED TRUE)
    # NOTE(review): the fallback is plain "90"; confirm whether any SM90 kernel
    # needs the arch-specific "90a" target instead.
    set(_sm90_archs "90")
    message(STATUS "GEMM: auto-enabling SM90 CUTLASS kernels for H100 backward compatibility")
  endif()
endif()
41+
342add_library (gemm2
443 gemm.cu
544 kernel.cu
@@ -16,28 +55,25 @@ add_library(gemm2
1655 tuner/sampler.cu
1756 tuner/stopping_criterion.cc
1857 tuner/params.cc
19- kernel/sm90_16816_4.cu
20- kernel/sm90_16816_8.cu
21- kernel/sm90_16816_16.cu
22- kernel/sm80_16816_4.cu
23- kernel/sm80_16816_8.cu
24- kernel/sm80_16816_16.cu
25- kernel/sm75_16816_4.cu
26- kernel/sm75_16816_8.cu
27- kernel/sm75_16816_16.cu
28- kernel/sm70_884_4.cu
29- kernel/sm70_884_8.cu
30- kernel/sm70_884_16.cu
31- kernel/sm90_64n32_8.cu
58+ ${GEMM2_KERNELS_SM70}
59+ ${GEMM2_KERNELS_SM75}
60+ ${GEMM2_KERNELS_SM80}
3261 cublas.cu
3362 moe_utils_v2.cu
3463 test /test_utils.cu
3564)
3665
# gemm2 links against the parser, nvidia::cutlass::cutlass and
# CUDA::cuda_driver targets; all three are listed as PRIVATE dependencies.
target_link_libraries(gemm2
  PRIVATE
    parser
    nvidia::cutlass::cutlass
    CUDA::cuda_driver
)
3867
39-
40- target_compile_definitions (gemm2 PRIVATE -DCUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED )
# cublasGemmGroupedBatchedEx (CUDA 12.5+): grouped batched GEMM for MoE on SM100.
#
# ENABLE_CUBLAS_GROUPED=1 is defined for gemm2 only when BOTH hold:
#   * CMAKE_CUDA_ARCHITECTURES has at least one entry starting with "100", and
#   * the CUDA compiler version is 12.5 or newer.
# _has_sm100 records that combined result, not merely the presence of an
# SM100 entry.
set(_has_sm100 FALSE)

# No padding inside the quotes, so list entries stay free of stray whitespace.
set(_archs_100 "${CMAKE_CUDA_ARCHITECTURES}")
list(FILTER _archs_100 INCLUDE REGEX "^100")

if(_archs_100 AND CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL "12.5")
  set(_has_sm100 TRUE)
  target_compile_definitions(gemm2 PRIVATE ENABLE_CUBLAS_GROUPED=1)
  message(STATUS "GEMM: ENABLE_CUBLAS_GROUPED=1 (cublasGemmGroupedBatchedEx for MoE on SM100)")
endif()
4177
4278target_compile_options (gemm2 PRIVATE
4379 $<$<COMPILE_LANGUAGE :CUDA >:
@@ -48,7 +84,26 @@ target_compile_options(gemm2 PRIVATE
# Build gemm2 as position-independent code and turn on CMake's
# CUDA_RESOLVE_DEVICE_SYMBOLS property for it.
set_target_properties(gemm2 PROPERTIES
  POSITION_INDEPENDENT_CODE ON
  CUDA_RESOLVE_DEVICE_SYMBOLS ON
)
5086
if(GEMM2_ARCH_90_ENABLED)
  # SM90 kernels only compile for 90/90a; avoid building them for sm_100.
  # They get their own static library so CUDA_ARCHITECTURES can be limited to
  # _sm90_archs without affecting the rest of gemm2.
  add_library(gemm2_sm90 STATIC ${GEMM2_KERNELS_SM90})

  # The architecture list is expanded without padding inside the quotes; a
  # trailing space would become part of the last architecture entry.
  set_target_properties(gemm2_sm90 PROPERTIES
    CUDA_ARCHITECTURES "${_sm90_archs}"
    POSITION_INDEPENDENT_CODE ON
    CUDA_RESOLVE_DEVICE_SYMBOLS ON
  )

  # Bare macro name: target_compile_definitions adds the -D itself.
  target_compile_definitions(gemm2_sm90 PRIVATE CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)

  # CUDA-only flags. The generator expression spans several arguments, so it
  # must contain no whitespace of its own between "$<", the condition and ":".
  target_compile_options(gemm2_sm90 PRIVATE
    $<$<COMPILE_LANGUAGE:CUDA>:
      -Xptxas=-v
      --generate-line-info
      --threads 16>
  )

  target_link_libraries(gemm2_sm90 PRIVATE parser nvidia::cutlass::cutlass CUDA::cuda_driver)

  # Pull the SM90 kernels into gemm2 and define GEMM2_ARCH_90_ENABLED for its
  # sources.
  target_link_libraries(gemm2 PRIVATE gemm2_sm90)
  target_compile_definitions(gemm2 PRIVATE GEMM2_ARCH_90_ENABLED)
endif()
52107
53108if (BUILD_TEST)
54109 add_executable (test_gemm_v2
0 commit comments