-
Notifications
You must be signed in to change notification settings - Fork 699
Expand file tree
/
Copy pathtest_grouped_gemm.cu
More file actions
778 lines (691 loc) · 31.8 KB
/
test_grouped_gemm.cu
File metadata and controls
778 lines (691 loc) · 31.8 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
/*************************************************************************
* Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <cublasLt.h>
#include <cuda_bf16.h>
#include <cuda_runtime.h>
#include <gtest/gtest.h>
#include <algorithm>
#include <memory>
#include <numeric>
#include <optional>
#include <random>
#include <tuple>
#include <vector>
#include <transformer_engine/cast.h>
#include <transformer_engine/gemm.h>
#include <transformer_engine/recipe.h>
#include <transformer_engine/swizzle.h>
#include <transformer_engine/transformer_engine.h>
#include "../test_common.h"
using namespace transformer_engine;
using namespace test;
namespace {
// Operand flavour exercised by a test case.
// The numeric values are used as indices into the name table in
// MakeGroupedGemmTestName, so they are spelled out explicitly.
enum class InputCase {
  kFP8Current = 0,  // operands built by make_fp8_operand (E4M3, scale derived from amax)
  kBF16 = 1,        // operands built by make_bf16_operand
  kMXFP8 = 2,       // operands built by make_mxfp8_operand (swizzled block scales)
};
// Selects which per-GEMM (M, N, K) dimensions vary inside a group (see make_shapes).
// The numeric values are used as indices into the name table in
// MakeGroupedGemmTestName, so they are spelled out explicitly.
enum class ShapeCase {
  kAllSame = 0,       // every GEMM in the group has identical (M, N, K)
  kSameFirst = 1,     // M is shared, N and K vary
  kSameLast = 2,      // N is shared, M and K vary
  kAllDifferent = 3,  // M, N and K all vary
};
// Returns the number of bytes the test reserves for the grouped-GEMM setup
// workspace holding per-tensor metadata for `num_tensors` GEMMs.
//
// Layout accounted for: 8 pointer arrays (A, B, C, D, alpha, beta, a_scale,
// b_scale) plus 6 int arrays, each with one entry per tensor. The total is
// rounded up to a multiple of 256 bytes.
size_t grouped_setup_workspace_size(const size_t num_tensors) {
  constexpr size_t kPointerArrays = 8;
  constexpr size_t kIntArrays = 6;
  constexpr size_t kAlignment = 256;

  // Bytes needed for one tensor's slot across all arrays.
  constexpr size_t kBytesPerTensor = kPointerArrays * sizeof(void*) + kIntArrays * sizeof(int);

  const size_t raw_bytes = num_tensors * kBytesPerTensor;
  const size_t aligned_chunks = (raw_bytes + kAlignment - 1) / kAlignment;
  return aligned_chunks * kAlignment;
}
// Builds an FP8 (E4M3) GEMM operand named `name` with the given `shape`.
//
// The operand is produced by filling an FP32 tensor with normally distributed
// values and quantizing it. The distribution's mean and stddev are themselves
// drawn at random (seeded from a hash of `name`), so tensors with different
// names end up with different amax values and therefore different scales.
//
// The returned tensor is allocated with both rowwise and columnwise data and
// uses NVTE_DELAYED_TENSOR_SCALING; its scale is computed from the amax of the
// data that was just generated.
Tensor make_fp8_operand(const std::string& name, const std::vector<size_t>& shape) {
  Tensor input_fp32(name + "_fp32", shape, DType::kFloat32);

  // Element count over every dimension. Unlike `shape[0] * shape[1]`, this does
  // not read out of bounds for rank < 2 and does not undercount for rank > 2.
  size_t numel = 1;
  for (const size_t dim : shape) {
    numel *= dim;
  }

  std::vector<float> data(numel);
  std::mt19937 gen(std::hash<std::string>{}(name));
  // Random mean and stddev -> different amax per tensor -> different scales
  std::uniform_real_distribution<float> param_dis(0.1f, 10.0f);
  float mean = param_dis(gen);
  float stddev = param_dis(gen);
  std::normal_distribution<float> dis(mean, stddev);
  for (size_t i = 0; i < numel; ++i) {
    data[i] = dis(gen);
  }
  NVTE_CHECK_CUDA(cudaMemcpy(input_fp32.rowwise_dptr(), data.data(),
                             numel * sizeof(float), cudaMemcpyHostToDevice));

  Tensor fp8(name, shape, TypeInfo<fp8e4m3>::dtype, true, true, NVTE_DELAYED_TENSOR_SCALING);
  // amax -> scale -> quantize, all issued on stream 0.
  nvte_compute_amax(input_fp32.data(), fp8.data(), 0);
  QuantizationConfigWrapper config;
  nvte_compute_scale_from_amax(fp8.data(), config, 0);
  nvte_quantize(input_fp32.data(), fp8.data(), 0);
  return fp8;
}
// Builds a BF16 GEMM operand named `name` with the given `shape`, with every
// element set to 1.0.
Tensor make_bf16_operand(const std::string& name, const std::vector<size_t>& shape) {
  Tensor t(name, shape, DType::kBFloat16);

  // Element count over every dimension. Unlike `shape[0] * shape[1]`, this does
  // not read out of bounds for rank < 2 and does not undercount for rank > 2.
  size_t numel = 1;
  for (const size_t dim : shape) {
    numel *= dim;
  }

  std::vector<__nv_bfloat16> ones(numel, __float2bfloat16(1.0f));
  NVTE_CHECK_CUDA(cudaMemcpy(t.rowwise_dptr(), ones.data(),
                             numel * sizeof(__nv_bfloat16), cudaMemcpyHostToDevice));
  return t;
}
// Builds an MXFP8 operand whose data layout and scales are ready for GEMM.
//
// MXFP8 GEMM wants the scales laid out along the K dimension, which fixes the
// layout each operand must carry:
//
//   operand | transposed | layout
//   --------+------------+-----------
//   A       | yes        | rowwise
//   A       | no         | columnwise
//   B       | yes        | columnwise
//   B       | no         | rowwise
//
// i.e. rowwise is needed exactly when `is_A == transposed`, columnwise otherwise.
//
// name       - tensor name; "_bf16" / "_swizzled" suffixes are used for temporaries/result
// shape      - operand shape
// is_A       - true for the A operand, false for B
// transposed - whether the operand is passed to the GEMM as transposed
//
// Returns a tensor holding the quantized data together with GEMM-swizzled scales.
Tensor make_mxfp8_operand(const std::string& name, const std::vector<size_t>& shape,
                          bool is_A, bool transposed) {
  // Exactly one of the two layouts is materialized (see table above).
  const bool use_rowwise = (is_A == transposed);
  const bool use_colwise = !use_rowwise;

  // High-precision source filled with random values.
  Tensor input_bf16(name + "_bf16", shape, DType::kBFloat16);
  fillUniform(&input_bf16);

  // Quantize BF16 -> MXFP8 into a tensor that only carries the needed layout.
  Tensor mxfp8(name, shape, TypeInfo<fp8e4m3>::dtype, use_rowwise, use_colwise,
               NVTE_MXFP8_1D_SCALING);
  nvte_quantize(input_bf16.data(), mxfp8.data(), 0);

  // Destination with identical shape/layout that will receive the swizzled scales.
  // The swizzled-scales flag has to be raised before the swizzle call below.
  Tensor mxfp8_swizzled(name + "_swizzled", shape, TypeInfo<fp8e4m3>::dtype,
                        use_rowwise, use_colwise, NVTE_MXFP8_1D_SCALING);
  mxfp8_swizzled.set_with_gemm_swizzled_scales(true);

  // The quantized payload itself is not changed by swizzling: copy it verbatim.
  if (use_rowwise) {
    const size_t payload_bytes = test::bytes(mxfp8.rowwise_shape(), mxfp8.dtype());
    NVTE_CHECK_CUDA(cudaMemcpy(mxfp8_swizzled.rowwise_dptr(), mxfp8.rowwise_dptr(),
                               payload_bytes, cudaMemcpyDeviceToDevice));
  } else {
    const size_t payload_bytes = test::bytes(mxfp8.columnwise_shape(), mxfp8.dtype());
    NVTE_CHECK_CUDA(cudaMemcpy(mxfp8_swizzled.columnwise_dptr(), mxfp8.columnwise_dptr(),
                               payload_bytes, cudaMemcpyDeviceToDevice));
  }

  // Rearrange the scaling factors into the layout the GEMM consumes.
  nvte_swizzle_scaling_factors(mxfp8.data(), mxfp8_swizzled.data(), 0);

  // `mxfp8` and `input_bf16` go out of scope on return; wait for the device first.
  NVTE_CHECK_CUDA(cudaDeviceSynchronize());
  return mxfp8_swizzled;
}
// One parameterized test configuration.
// Field order matters: the parameter table initializes this struct positionally.
struct TestParams {
  InputCase input_case;    // which operand flavour to generate
  bool transa;             // pass A as transposed
  bool transb;             // pass B as transposed
  ShapeCase shape_case;    // which (M, N, K) set to use, see make_shapes
  bool use_null_c{false};  // pass nullptr for C instead of a tensor (valid when beta=0)
};
// Produces the (M, N, K) triple of every GEMM in the group for `scase`.
//   M - rows of the output D
//   N - columns of the output D
//   K - reduction dimension shared by A and B
// Every case yields three GEMMs. Any value other than kAllSame, kSameFirst or
// kSameLast is treated as kAllDifferent.
std::vector<std::tuple<size_t, size_t, size_t>> make_shapes(ShapeCase scase) {
  if (scase == ShapeCase::kAllSame) {
    return {{128, 256, 384}, {128, 256, 384}, {128, 256, 384}};
  }
  if (scase == ShapeCase::kSameFirst) {
    // M (first dim) is fixed; N and K grow.
    return {{128, 256, 384}, {128, 384, 512}, {128, 512, 640}};
  }
  if (scase == ShapeCase::kSameLast) {
    // N (last dim) is fixed; M and K grow.
    return {{128, 256, 384}, {256, 256, 512}, {384, 256, 640}};
  }
  // kAllDifferent, and the fallback for anything unexpected.
  return {{128, 256, 384}, {256, 384, 512}, {384, 512, 640}};
}
#if CUBLAS_VERSION >= 130300
// ---------------------------------------------------------------------------
// Helpers shared by the three grouped-GEMM test drivers below. They are only
// compiled when the drivers' bodies are (cuBLAS 13.3+), so that older builds
// do not end up with unused functions.
// ---------------------------------------------------------------------------

// Size of every cuBLAS workspace allocated by these tests (32 MiB).
constexpr size_t kCublasWorkspaceBytes = 32ull * 1024 * 1024;

// Operands shared by all grouped-GEMM variants plus the reference outputs
// computed by nvte_multi_tensor_gemm.
struct GroupedGemmReference {
  std::vector<std::tuple<size_t, size_t, size_t>> shapes;  // (M, N, K) per GEMM
  std::vector<Tensor> A_tensors;                           // per-GEMM A operands
  std::vector<Tensor> B_tensors;                           // per-GEMM B operands
  std::vector<Tensor> D_multi;                             // reference outputs, shape (M, N)
  // Workspaces used by the reference GEMMs. Owned here so that they stay
  // allocated until the calling test has finished, as the reference call is
  // not followed by an explicit synchronization.
  std::vector<Tensor> workspaces;
};

// Returns the reason why grouped GEMM cannot be tested on the current device,
// or nullptr when the tests can run.
const char* grouped_gemm_skip_reason() {
  const int compute_capability = getDeviceComputeCapability();
  if (compute_capability < blackwellComputeCapability) {
    return "Grouped GEMM requires Blackwell (SM100) or newer.";
  }
  if (compute_capability == 120) {
    return "Grouped GEMM is currently unsupported on SM120.";
  }
  return nullptr;
}

// Creates the A/B operands requested by `params`, allocates the reference
// outputs and launches nvte_multi_tensor_gemm on stream 0 to fill them.
// `ref` must be default-constructed on entry.
void build_grouped_gemm_reference(const TestParams& params, GroupedGemmReference& ref) {
  ref.shapes = make_shapes(params.shape_case);
  const size_t num_gemms = ref.shapes.size();

  // Reserve up front: element addresses are handed out to callers later.
  ref.A_tensors.reserve(num_gemms);
  ref.B_tensors.reserve(num_gemms);
  ref.D_multi.reserve(num_gemms);
  ref.workspaces.reserve(num_gemms);

  for (size_t i = 0; i < num_gemms; ++i) {
    const auto [M, N, K] = ref.shapes[i];
    const std::string idx = std::to_string(i);
    const std::vector<size_t> a_shape = params.transa ? std::vector<size_t>{N, K}
                                                      : std::vector<size_t>{K, N};
    const std::vector<size_t> b_shape = params.transb ? std::vector<size_t>{K, M}
                                                      : std::vector<size_t>{M, K};
    switch (params.input_case) {
      case InputCase::kFP8Current: {
        ref.A_tensors.emplace_back(make_fp8_operand("A" + idx, a_shape));
        ref.B_tensors.emplace_back(make_fp8_operand("B" + idx, b_shape));
        break;
      }
      case InputCase::kBF16: {
        ref.A_tensors.emplace_back(make_bf16_operand("A" + idx, a_shape));
        ref.B_tensors.emplace_back(make_bf16_operand("B" + idx, b_shape));
        break;
      }
      case InputCase::kMXFP8: {
        ref.A_tensors.emplace_back(make_mxfp8_operand("A" + idx, a_shape,
                                                      /*is_A=*/true, params.transa));
        ref.B_tensors.emplace_back(make_mxfp8_operand("B" + idx, b_shape,
                                                      /*is_A=*/false, params.transb));
        break;
      }
    }
    ref.D_multi.emplace_back(Tensor("D_multi" + idx,
                                    std::vector<size_t>{M, N},
                                    DType::kBFloat16));
  }

  std::vector<NVTETensor> A_ptrs(num_gemms);
  std::vector<NVTETensor> B_ptrs(num_gemms);
  std::vector<NVTETensor> D_ptrs(num_gemms);
  std::vector<NVTETensor> workspace_ptrs(num_gemms, nullptr);
  // Empty bias/gelu arrays for nvte_multi_tensor_gemm (no epilogues)
  std::vector<NVTETensor> bias_ptrs(num_gemms, nullptr);
  std::vector<NVTETensor> gelu_ptrs(num_gemms, nullptr);
  for (size_t i = 0; i < num_gemms; ++i) {
    A_ptrs[i] = ref.A_tensors[i].data();
    B_ptrs[i] = ref.B_tensors[i].data();
    D_ptrs[i] = ref.D_multi[i].data();
    ref.workspaces.emplace_back(Tensor("workspace" + std::to_string(i),
                                       std::vector<size_t>{kCublasWorkspaceBytes},
                                       DType::kByte));
    workspace_ptrs[i] = ref.workspaces.back().data();
  }

  nvte_multi_tensor_gemm(A_ptrs.data(),
                         B_ptrs.data(),
                         D_ptrs.data(),
                         bias_ptrs.data(),
                         gelu_ptrs.data(),
                         static_cast<int>(num_gemms),
                         params.transa,
                         params.transb,
                         false,  // grad
                         workspace_ptrs.data(),
                         false,  // accumulate
                         false,  // use_split_accumulator
                         0,      // sm_count
                         0);
}

// Returns the addresses of all elements of `tensors`, in order. The pointers
// stay valid only as long as `tensors` is not resized.
std::vector<Tensor*> make_tensor_views(std::vector<Tensor>& tensors) {
  std::vector<Tensor*> views;
  views.reserve(tensors.size());
  for (Tensor& t : tensors) {
    views.push_back(&t);
  }
  return views;
}

// Allocates the per-GEMM BF16 C and D tensors of shape (M, N).
// C tensors ("C<i>") are only created when `use_null_c` is false and are left
// uninitialized; D tensors ("<d_prefix><i>") are zero-filled.
void make_output_tensors(const GroupedGemmReference& ref, bool use_null_c,
                         const std::string& d_prefix,
                         std::vector<Tensor>& C_tensors, std::vector<Tensor>& D_tensors) {
  const size_t num_gemms = ref.shapes.size();
  C_tensors.reserve(num_gemms);
  D_tensors.reserve(num_gemms);
  for (size_t i = 0; i < num_gemms; ++i) {
    const auto [M, N, K] = ref.shapes[i];
    (void)K;
    if (!use_null_c) {
      C_tensors.emplace_back(Tensor("C" + std::to_string(i),
                                    std::vector<size_t>{M, N},
                                    DType::kBFloat16));
    }
    D_tensors.emplace_back(Tensor(d_prefix + std::to_string(i),
                                  std::vector<size_t>{M, N},
                                  DType::kBFloat16));
    Tensor& d = D_tensors.back();
    NVTE_CHECK_CUDA(cudaMemset(d.rowwise_dptr(), 0, bytes(d.rowwise_shape(), d.dtype())));
  }
}

// Creates a 1-D FP32 device tensor with `count` elements, all equal to `value`.
Tensor make_filled_float_tensor(const std::string& name, size_t count, float value) {
  Tensor t(name, std::vector<size_t>{count}, DType::kFloat32);
  const std::vector<float> host_vals(count, value);
  NVTE_CHECK_CUDA(cudaMemcpy(t.rowwise_dptr(), host_vals.data(),
                             count * sizeof(float), cudaMemcpyHostToDevice));
  return t;
}

// Splits the contiguous grouped output `grouped_D` back into per-GEMM tensors
// and compares each against the matching reference output in `ref`.
// `label` is forwarded to compareResults as the comparison name.
void compare_grouped_output(const std::string& label, GroupedBuffers& grouped_D,
                            GroupedGemmReference& ref) {
  for (size_t i = 0; i < ref.shapes.size(); ++i) {
    Tensor grouped_split("grouped_D" + std::to_string(i),
                         std::vector<size_t>{std::get<0>(ref.shapes[i]),
                                             std::get<1>(ref.shapes[i])},
                         ref.D_multi[i].dtype());
    // offsets_host is in elements; convert to a byte offset into the grouped buffer.
    const size_t offset_bytes =
        static_cast<size_t>(grouped_D.offsets_host[i]) * grouped_D.elem_size;
    NVTE_CHECK_CUDA(cudaMemcpy(grouped_split.rowwise_dptr(),
                               static_cast<char*>(grouped_D.get_data()) + offset_bytes,
                               grouped_D.tensor_bytes[i],
                               cudaMemcpyDeviceToDevice));
    grouped_split.to_cpu();
    ref.D_multi[i].to_cpu();
    auto [atol, rtol] = getTolerances(ref.D_multi[i].dtype());
    compareResults(label,
                   grouped_split,
                   ref.D_multi[i].rowwise_cpu_dptr<bf16>(),
                   true,
                   atol,
                   rtol);
  }
}
#endif  // CUBLAS_VERSION >= 130300

// Runs nvte_grouped_gemm with grouped A, B, C and D and checks every output
// against the per-tensor reference from nvte_multi_tensor_gemm.
// Skips when cuBLAS is older than 13.3 or the device is unsupported.
void run_grouped_gemm_case(const TestParams& params) {
#if CUBLAS_VERSION < 130300
  GTEST_SKIP() << "Grouped GEMM requires cuBLAS 13.3+, but compile-time cuBLAS version is "
               << CUBLAS_VERSION << ".";
#else
  if (const char* reason = grouped_gemm_skip_reason()) {
    GTEST_SKIP() << reason;
  }

  GroupedGemmReference ref;
  build_grouped_gemm_reference(params, ref);
  const size_t num_gemms = ref.shapes.size();

  std::vector<Tensor*> A_views = make_tensor_views(ref.A_tensors);
  std::vector<Tensor*> B_views = make_tensor_views(ref.B_tensors);
  GroupedBuffers grouped_A = build_grouped_tensor(A_views, ref.A_tensors[0].scaling_mode());
  GroupedBuffers grouped_B = build_grouped_tensor(B_views, ref.B_tensors[0].scaling_mode());

  std::vector<Tensor> C_tensors;
  std::vector<Tensor> D_group_tensors;
  make_output_tensors(ref, params.use_null_c, "D_group", C_tensors, D_group_tensors);
  std::vector<Tensor*> C_views = make_tensor_views(C_tensors);
  std::vector<Tensor*> D_views = make_tensor_views(D_group_tensors);

  std::optional<GroupedBuffers> grouped_C;
  if (!params.use_null_c) {
    grouped_C = build_grouped_tensor(C_views, NVTE_DELAYED_TENSOR_SCALING);
  }
  GroupedBuffers grouped_D = build_grouped_tensor(D_views, NVTE_DELAYED_TENSOR_SCALING);

  // Per-matrix alpha/beta (all 1.0 and 0.0 respectively)
  Tensor alpha_tensor = make_filled_float_tensor("alpha", num_gemms, 1.f);
  Tensor beta_tensor = make_filled_float_tensor("beta", num_gemms, 0.f);

  const size_t setup_ws_bytes = grouped_setup_workspace_size(num_gemms);
  Tensor setup_ws("setup_ws", std::vector<size_t>{setup_ws_bytes}, DType::kByte);
  Tensor cublas_ws("cublas_ws", std::vector<size_t>{kCublasWorkspaceBytes}, DType::kByte);

  nvte_grouped_gemm(grouped_A.get_handle(),
                    params.transa,
                    grouped_B.get_handle(),
                    params.transb,
                    params.use_null_c ? nullptr : grouped_C->get_handle(),
                    grouped_D.get_handle(),
                    alpha_tensor.data(),
                    beta_tensor.data(),
                    setup_ws.data(),
                    cublas_ws.data(),
                    nullptr,  // config (use defaults)
                    0);
  NVTE_CHECK_CUDA(cudaDeviceSynchronize());

  // Compare results
  compare_grouped_output("grouped_vs_multi", grouped_D, ref);
#endif  // CUBLAS_VERSION >= 130300
}

// Runs nvte_grouped_gemm_with_discrete_out: A and B are grouped, while C and D
// are passed as lists of individual tensors. Every D is checked against the
// per-tensor reference from nvte_multi_tensor_gemm.
// Skips when cuBLAS is older than 13.3 or the device is unsupported.
void run_grouped_gemm_discrete_out_case(const TestParams& params) {
#if CUBLAS_VERSION < 130300
  GTEST_SKIP() << "Grouped GEMM requires cuBLAS 13.3+, but compile-time cuBLAS version is "
               << CUBLAS_VERSION << ".";
#else
  if (const char* reason = grouped_gemm_skip_reason()) {
    GTEST_SKIP() << reason;
  }

  GroupedGemmReference ref;
  build_grouped_gemm_reference(params, ref);
  const size_t num_gemms = ref.shapes.size();

  std::vector<Tensor*> A_views = make_tensor_views(ref.A_tensors);
  std::vector<Tensor*> B_views = make_tensor_views(ref.B_tensors);
  GroupedBuffers grouped_A = build_grouped_tensor(A_views, ref.A_tensors[0].scaling_mode());
  GroupedBuffers grouped_B = build_grouped_tensor(B_views, ref.B_tensors[0].scaling_mode());

  std::vector<Tensor> C_tensors;
  std::vector<Tensor> D_list_tensors;
  make_output_tensors(ref, params.use_null_c, "D_list", C_tensors, D_list_tensors);

  // Handles of the discrete C/D tensors, in GEMM order.
  std::vector<NVTETensor> C_list_ptrs;
  std::vector<NVTETensor> D_list_ptrs;
  if (!params.use_null_c) {
    C_list_ptrs.reserve(num_gemms);
  }
  D_list_ptrs.reserve(num_gemms);
  for (size_t i = 0; i < num_gemms; ++i) {
    if (!params.use_null_c) {
      C_list_ptrs.push_back(C_tensors[i].data());
    }
    D_list_ptrs.push_back(D_list_tensors[i].data());
  }

  // Per-matrix alpha/beta (all 1.0 and 0.0 respectively)
  Tensor alpha_tensor = make_filled_float_tensor("alpha", num_gemms, 1.f);
  Tensor beta_tensor = make_filled_float_tensor("beta", num_gemms, 0.f);

  const size_t setup_ws_bytes = grouped_setup_workspace_size(num_gemms);
  Tensor setup_ws("setup_ws", std::vector<size_t>{setup_ws_bytes}, DType::kByte);
  Tensor cublas_ws("cublas_ws", std::vector<size_t>{kCublasWorkspaceBytes}, DType::kByte);

  nvte_grouped_gemm_with_discrete_out(grouped_A.get_handle(),
                                      params.transa,
                                      grouped_B.get_handle(),
                                      params.transb,
                                      params.use_null_c ? nullptr : C_list_ptrs.data(),
                                      params.use_null_c ? 0 : num_gemms,
                                      D_list_ptrs.data(),
                                      num_gemms,
                                      alpha_tensor.data(),
                                      beta_tensor.data(),
                                      setup_ws.data(),
                                      cublas_ws.data(),
                                      nullptr,  // config (use defaults)
                                      0);
  NVTE_CHECK_CUDA(cudaDeviceSynchronize());

  // Compare results: the outputs are already separate tensors, no splitting needed.
  for (size_t i = 0; i < num_gemms; ++i) {
    D_list_tensors[i].to_cpu();
    ref.D_multi[i].to_cpu();
    auto [atol, rtol] = getTolerances(ref.D_multi[i].dtype());
    compareResults("grouped_list_vs_multi",
                   D_list_tensors[i],
                   ref.D_multi[i].rowwise_cpu_dptr<bf16>(),
                   true,
                   atol,
                   rtol);
  }
#endif  // CUBLAS_VERSION >= 130300
}

// Runs nvte_grouped_gemm_with_discrete_inputA: A is passed as a list of
// individual tensors, while B, C and D are grouped. Every output is checked
// against the per-tensor reference from nvte_multi_tensor_gemm.
// Skips when cuBLAS is older than 13.3 or the device is unsupported.
void run_grouped_gemm_discrete_in_case(const TestParams& params) {
#if CUBLAS_VERSION < 130300
  GTEST_SKIP() << "Grouped GEMM requires cuBLAS 13.3+, but compile-time cuBLAS version is "
               << CUBLAS_VERSION << ".";
#else
  if (const char* reason = grouped_gemm_skip_reason()) {
    GTEST_SKIP() << reason;
  }

  GroupedGemmReference ref;
  build_grouped_gemm_reference(params, ref);
  const size_t num_gemms = ref.shapes.size();

  std::vector<Tensor*> B_views = make_tensor_views(ref.B_tensors);
  GroupedBuffers grouped_B = build_grouped_tensor(B_views, ref.B_tensors[0].scaling_mode());

  std::vector<Tensor> C_tensors;
  std::vector<Tensor> D_group_tensors;
  make_output_tensors(ref, params.use_null_c, "D_group", C_tensors, D_group_tensors);
  std::vector<Tensor*> C_views = make_tensor_views(C_tensors);
  std::vector<Tensor*> D_views = make_tensor_views(D_group_tensors);

  std::optional<GroupedBuffers> grouped_C;
  if (!params.use_null_c) {
    grouped_C = build_grouped_tensor(C_views, NVTE_DELAYED_TENSOR_SCALING);
  }
  GroupedBuffers grouped_D = build_grouped_tensor(D_views, NVTE_DELAYED_TENSOR_SCALING);

  // Per-matrix alpha/beta (all 1.0 and 0.0 respectively)
  Tensor alpha_tensor = make_filled_float_tensor("alpha", num_gemms, 1.f);
  Tensor beta_tensor = make_filled_float_tensor("beta", num_gemms, 0.f);

  const size_t setup_ws_bytes = grouped_setup_workspace_size(num_gemms);
  Tensor setup_ws("setup_ws", std::vector<size_t>{setup_ws_bytes}, DType::kByte);
  Tensor cublas_ws("cublas_ws", std::vector<size_t>{kCublasWorkspaceBytes}, DType::kByte);

  // Handles of the discrete A tensors, in GEMM order.
  std::vector<NVTETensor> A_list_ptrs;
  A_list_ptrs.reserve(num_gemms);
  for (size_t i = 0; i < num_gemms; ++i) {
    A_list_ptrs.push_back(ref.A_tensors[i].data());
  }

  nvte_grouped_gemm_with_discrete_inputA(A_list_ptrs.data(),
                                         num_gemms,
                                         params.transa,
                                         grouped_B.get_handle(),
                                         params.transb,
                                         params.use_null_c ? nullptr : grouped_C->get_handle(),
                                         grouped_D.get_handle(),
                                         alpha_tensor.data(),
                                         beta_tensor.data(),
                                         setup_ws.data(),
                                         cublas_ws.data(),
                                         nullptr,  // config (use defaults)
                                         0);
  NVTE_CHECK_CUDA(cudaDeviceSynchronize());

  // Compare results
  compare_grouped_output("grouped_discrete_in_vs_multi", grouped_D, ref);
#endif  // CUBLAS_VERSION >= 130300
}
// Value-parameterized fixture; each instance receives one TestParams entry.
class GroupedGemmTest : public ::testing::TestWithParam<TestParams> {};
// Grouped A/B/C/D via nvte_grouped_gemm, checked against nvte_multi_tensor_gemm.
TEST_P(GroupedGemmTest, CompareWithMultiTensorGemm) {
run_grouped_gemm_case(GetParam());
}
// Same comparison with C and D passed as lists of individual tensors.
TEST_P(GroupedGemmTest, CompareWithMultiTensorGemmDiscreteOut) {
run_grouped_gemm_discrete_out_case(GetParam());
}
// Same comparison with A passed as a list of individual tensors.
TEST_P(GroupedGemmTest, CompareWithMultiTensorGemmDiscreteIn) {
run_grouped_gemm_discrete_in_case(GetParam());
}
// Builds the test-name suffix for one parameter set:
//   <Input>_<Shape>_ta<T|N>tb<T|N>[_NullC]
// The label tables are indexed by the integer value of InputCase / ShapeCase
// and must therefore list their entries in enumerator order.
std::string MakeGroupedGemmTestName(const testing::TestParamInfo<GroupedGemmTest::ParamType>& info) {
  constexpr const char* kInputLabels[] = {"FP8Current", "BF16", "MXFP8"};
  constexpr const char* kShapeLabels[] = {"AllSame", "SameM", "SameN", "AllDiff"};

  const auto& p = info.param;

  std::string test_name(kInputLabels[static_cast<int>(p.input_case)]);
  test_name += "_";
  test_name += kShapeLabels[static_cast<int>(p.shape_case)];
  test_name += "_";

  // Transpose layout, e.g. "taTtbN".
  test_name += "ta";
  test_name += p.transa ? "T" : "N";
  test_name += "tb";
  test_name += p.transb ? "T" : "N";

  if (p.use_null_c) {
    test_name += "_NullC";
  }
  return test_name;
}
// Parameter sets shared by every TEST_P of GroupedGemmTest.
// Entries are positional: {input_case, transa, transb, shape_case, use_null_c}.
// Each entry must map to a distinct name in MakeGroupedGemmTestName.
const std::vector<TestParams> kTestParams = {
// FP8 tests (each tensor has random mean/stddev -> different scales)
{InputCase::kFP8Current, true, false, ShapeCase::kAllDifferent, false},
{InputCase::kFP8Current, false, true, ShapeCase::kAllDifferent, false},
{InputCase::kFP8Current, false, false, ShapeCase::kAllSame, false},
// BF16 tests
{InputCase::kBF16, true, false, ShapeCase::kSameFirst, false},
{InputCase::kBF16, false, true, ShapeCase::kSameLast, false},
{InputCase::kBF16, false, false, ShapeCase::kAllSame, false},
{InputCase::kBF16, true, true, ShapeCase::kAllDifferent, false},
// Test NULL C (valid when beta=0)
{InputCase::kBF16, false, false, ShapeCase::kAllSame, true},
// MXFP8 tests
{InputCase::kMXFP8, true, false, ShapeCase::kAllSame, false},
{InputCase::kMXFP8, true, false, ShapeCase::kAllDifferent, false},
{InputCase::kMXFP8, false, true, ShapeCase::kAllSame, false},
{InputCase::kMXFP8, false, true, ShapeCase::kAllDifferent, false},
{InputCase::kMXFP8, false, false, ShapeCase::kAllSame, false},
{InputCase::kMXFP8, false, false, ShapeCase::kAllDifferent, false},
{InputCase::kMXFP8, false, false, ShapeCase::kSameFirst, false},
// MXFP8 with NULL C
{InputCase::kMXFP8, true, false, ShapeCase::kAllSame, true},
};
// Instantiates every GroupedGemmTest TEST_P once per kTestParams entry under
// the "OperatorTest" prefix, naming instances with MakeGroupedGemmTestName.
INSTANTIATE_TEST_SUITE_P(OperatorTest,
GroupedGemmTest,
::testing::ValuesIn(kTestParams),
MakeGroupedGemmTestName);
} // namespace