Skip to content

Commit 01c1958

Browse files
author
da.huo
committed
feat(turbomind): integrate cublasGemmGroupedBatchedEx for Qwen3.5 MoE inference on Blackwell GPUs with memory copy optimizations
Made-with: Cursor
1 parent 2ef9c6b commit 01c1958

File tree

11 files changed

+439
-22
lines changed

11 files changed

+439
-22
lines changed

CMakeLists.txt

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -244,8 +244,11 @@ if(ARCH STREQUAL "x86_64")
244244
if (${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL "12.8")
245245
list(APPEND CMAKE_CUDA_ARCHITECTURES 120a-real) # 5090
246246
endif ()
247+
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL "12.8")
248+
list(APPEND CMAKE_CUDA_ARCHITECTURES 100a-real) # B200
249+
endif()
247250
if (MSVC)
248-
list(REMOVE_ITEM CMAKE_CUDA_ARCHITECTURES 80-real 90a-real)
251+
list(REMOVE_ITEM CMAKE_CUDA_ARCHITECTURES 80-real 90a-real 100a-real)
249252
endif ()
250253
endif ()
251254
elseif(ARCH STREQUAL "aarch64")

src/turbomind/core/copy.cc

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,14 @@ const auto& GetCopyAPI()
5757
void* fpn{};
5858
TM_CHECK_EQ(cudaGetDriverEntryPoint(symbol, &fpn, cudaEnableDefault, &status), 0);
5959
if (fpn && status == cudaDriverEntryPointSuccess) {
60+
// cuMemcpyBatchAsync crashes on sm_100 (Blackwell); return the empty monostate so callers fall back to the serialized copy path.
61+
int device = 0;
62+
(void)cudaGetDevice(&device);
63+
int major = 0;
64+
(void)cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, device);
65+
if (major >= 10) {
66+
return {};
67+
}
6068
return (PFN_cuMemcpyBatchAsync_v12080)fpn;
6169
}
6270
else {

src/turbomind/kernels/gemm/CMakeLists.txt

Lines changed: 70 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,44 @@
11
# Copyright (c) OpenMMLab. All rights reserved.
22

3+
set(GEMM2_KERNELS_SM70
4+
kernel/sm70_884_4.cu
5+
kernel/sm70_884_8.cu
6+
kernel/sm70_884_16.cu
7+
)
8+
set(GEMM2_KERNELS_SM75
9+
kernel/sm75_16816_4.cu
10+
kernel/sm75_16816_8.cu
11+
kernel/sm75_16816_16.cu
12+
)
13+
set(GEMM2_KERNELS_SM80
14+
kernel/sm80_16816_4.cu
15+
kernel/sm80_16816_8.cu
16+
kernel/sm80_16816_16.cu
17+
)
18+
set(GEMM2_KERNELS_SM90
19+
kernel/sm90_16816_4.cu
20+
kernel/sm90_16816_8.cu
21+
kernel/sm90_16816_16.cu
22+
kernel/sm90_64n32_8.cu
23+
)
24+
25+
set(GEMM2_ARCH_90_ENABLED FALSE)
26+
set(_sm90_archs "${CMAKE_CUDA_ARCHITECTURES}")
27+
list(FILTER _sm90_archs INCLUDE REGEX "^90")
28+
if(_sm90_archs)
29+
set(GEMM2_ARCH_90_ENABLED TRUE)
30+
else()
31+
# When building for SM100+ without explicit SM90, still compile SM90 CUTLASS
32+
# kernels so the fat binary can run MoE models on H100 (CUTLASS fused path).
33+
set(_sm100_archs "${CMAKE_CUDA_ARCHITECTURES}")
34+
list(FILTER _sm100_archs INCLUDE REGEX "^100")
35+
if(_sm100_archs)
36+
set(GEMM2_ARCH_90_ENABLED TRUE)
37+
set(_sm90_archs "90")
38+
message(STATUS "GEMM: auto-enabling SM90 CUTLASS kernels for H100 backward compatibility")
39+
endif()
40+
endif()
41+
342
add_library(gemm2
443
gemm.cu
544
kernel.cu
@@ -16,28 +55,25 @@ add_library(gemm2
1655
tuner/sampler.cu
1756
tuner/stopping_criterion.cc
1857
tuner/params.cc
19-
kernel/sm90_16816_4.cu
20-
kernel/sm90_16816_8.cu
21-
kernel/sm90_16816_16.cu
22-
kernel/sm80_16816_4.cu
23-
kernel/sm80_16816_8.cu
24-
kernel/sm80_16816_16.cu
25-
kernel/sm75_16816_4.cu
26-
kernel/sm75_16816_8.cu
27-
kernel/sm75_16816_16.cu
28-
kernel/sm70_884_4.cu
29-
kernel/sm70_884_8.cu
30-
kernel/sm70_884_16.cu
31-
kernel/sm90_64n32_8.cu
58+
${GEMM2_KERNELS_SM70}
59+
${GEMM2_KERNELS_SM75}
60+
${GEMM2_KERNELS_SM80}
3261
cublas.cu
3362
moe_utils_v2.cu
3463
test/test_utils.cu
3564
)
3665

3766
target_link_libraries(gemm2 PRIVATE parser nvidia::cutlass::cutlass CUDA::cuda_driver)
3867

39-
40-
target_compile_definitions(gemm2 PRIVATE -DCUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
68+
# cublasGemmGroupedBatchedEx (CUDA 12.5+): grouped batched GEMM for MoE on SM100
69+
set(_has_sm100 FALSE)
70+
set(_archs_100 "${CMAKE_CUDA_ARCHITECTURES}")
71+
list(FILTER _archs_100 INCLUDE REGEX "^100")
72+
if(_archs_100 AND CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL "12.5")
73+
set(_has_sm100 TRUE)
74+
target_compile_definitions(gemm2 PRIVATE ENABLE_CUBLAS_GROUPED=1)
75+
message(STATUS "GEMM: ENABLE_CUBLAS_GROUPED=1 (cublasGemmGroupedBatchedEx for MoE on SM100)")
76+
endif()
4177

4278
target_compile_options(gemm2 PRIVATE
4379
$<$<COMPILE_LANGUAGE:CUDA>:
@@ -48,7 +84,26 @@ target_compile_options(gemm2 PRIVATE
4884
set_property(TARGET gemm2 PROPERTY POSITION_INDEPENDENT_CODE ON)
4985
set_property(TARGET gemm2 PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)
5086

87+
if(GEMM2_ARCH_90_ENABLED)
88+
# SM90 kernels only compile for 90/90a; avoid building them for sm_100.
89+
add_library(gemm2_sm90 STATIC ${GEMM2_KERNELS_SM90})
90+
set_target_properties(gemm2_sm90 PROPERTIES
91+
CUDA_ARCHITECTURES "${_sm90_archs}"
92+
POSITION_INDEPENDENT_CODE ON
93+
CUDA_RESOLVE_DEVICE_SYMBOLS ON
94+
)
95+
target_compile_definitions(gemm2_sm90 PRIVATE -DCUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
96+
target_compile_options(gemm2_sm90 PRIVATE
97+
$<$<COMPILE_LANGUAGE:CUDA>:
98+
-Xptxas=-v
99+
--generate-line-info
100+
--threads 16>
101+
)
102+
target_link_libraries(gemm2_sm90 PRIVATE parser nvidia::cutlass::cutlass CUDA::cuda_driver)
103+
target_link_libraries(gemm2 PRIVATE gemm2_sm90)
51104

105+
target_compile_definitions(gemm2 PRIVATE GEMM2_ARCH_90_ENABLED)
106+
endif()
52107

53108
if (BUILD_TEST)
54109
add_executable(test_gemm_v2

src/turbomind/kernels/gemm/arch.h

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,20 @@ struct Sm80: Arch<800, 900> {
2626
static constexpr int value = 800;
2727
};
2828

29-
struct Sm90: Arch<900> {
29+
struct Sm90: Arch<900, 1000> {
3030
static constexpr int value = 900;
3131
};
3232

33+
// B200 (Blackwell) SM 100
34+
struct Sm100: Arch<1000, 1200> {
35+
static constexpr int value = 1000;
36+
};
37+
38+
// SM 12.x (e.g. sm_120): reuses the CUTLASS SM90 kernel family, matching the open-ended Sm90 range that existed before this change
39+
struct Sm120: Arch<1200, 1300> {
40+
static constexpr int value = 1200;
41+
};
42+
3343
inline bool is_arch_compatible(int karch, int darch)
3444
{
3545
switch (karch) {
@@ -42,7 +52,11 @@ inline bool is_arch_compatible(int karch, int darch)
4252
case 800:
4353
return Sm80::is_compatible(darch);
4454
case 900:
45-
return Sm90::is_compatible(darch);
55+
return Sm90::is_compatible(darch) || Sm120::is_compatible(darch);
56+
case 1000:
57+
return Sm100::is_compatible(darch);
58+
case 1200:
59+
return Sm120::is_compatible(darch);
4660
default:
4761
return false;
4862
}

src/turbomind/kernels/gemm/convert_v3.cu

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,9 @@ std::array<const LayoutConverter*, 2> GetConverters(DataType data_type,
105105
if (weight_type == kHalf || weight_type == kBfloat16) {
106106
constexpr Cvt<uint16_t, uint16_t> W;
107107
if (grouped) {
108+
// SM 10.x only: the CublasGroupedKernel path (cublasGemmGroupedBatchedEx) consumes weights in the standard (K,N) layout, so no layout converter is needed here
109+
if (sm >= 100 && sm < 120)
110+
return {};
108111
// clang-format off
109112
if (sm >= 80) return {W(sm8_, kRow, s16816h | B | _1), {}};
110113
if (sm == 75) return {W(sm75, kRow, s16816h | B | _1), {}};

0 commit comments

Comments
 (0)