# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""JAX te modules"""
import math
import operator
import os
from collections.abc import Iterable
from dataclasses import dataclass
from functools import partial, reduce, cache
from typing import Tuple, Sequence, Union
from enum import Enum
import warnings
import jax
import jax.numpy as jnp
from jax import dtypes
from jax.sharding import NamedSharding, PartitionSpec
from jax.experimental.custom_partitioning import SdyShardingRule
from transformer_engine_jax import (
get_num_compute_streams,
JAXX_Collective_Op,
get_device_compute_capability,
initialize_cgemm_communicator,
get_cgemm_num_max_streams,
get_grouped_gemm_setup_workspace_size,
)
from .base import BasePrimitive, register_primitive
from .quantization import grouped_quantize
from ..quantize import (
AbstractBaseTensor,
NoScaleTensor,
ScaledTensor,
ScaledTensor1x,
ScaledTensor2x,
GroupedScaledTensor1x,
GroupedNoScaleTensor,
ScalingMode,
Quantizer,
GroupedQuantizer,
QuantizerSet,
noop_quantizer_set,
is_fp8_gemm_with_all_layouts_supported,
apply_padding_to_scale_inv,
QuantizeLayout,
)
from .misc import get_padded_spec, is_all_reduce_in_float32
from ..sharding import (
global_mesh_resource,
tpsp_axis_size,
dp_or_fsdp_axis_size,
)
__all__ = [
"CollectiveOp",
"CollectiveOpSet",
"collective_gemm_bootstrap",
"noop_collective_op_set",
"gemm",
"grouped_gemm_copy_group_sizes",
"grouped_gemm",
"sanitize_dims",
"get_non_contracting_dims",
"transpose_dims",
]
num_cublas_streams = get_num_compute_streams()
# Cache whether the CUDA-graphable grouped GEMM implementation is available at import time.
# Calling get_grouped_gemm_setup_workspace_size raises a RuntimeError mentioning "cublas" when
# compiled against cuBLAS < 13.2, in which case the cuda-graphable path is unavailable.
_v2_grouped_gemm_available_reason = ""
try:
get_grouped_gemm_setup_workspace_size(1)
_v2_grouped_gemm_available = True
except RuntimeError as e:
if "cublas" in str(e).lower():
_v2_grouped_gemm_available = False
_v2_grouped_gemm_available_reason = str(e)
else:
raise
def get_cublas_workspace_size_bytes() -> int:
    """Return 32 MiB for Hopper (sm90) and newer architectures, 4 MiB otherwise."""
if get_device_compute_capability(0) >= 90:
return 33_554_432
return 4_194_304
def sanitize_dims(ndim: int, dims: Union[int, Sequence[int]]) -> Sequence[int]:
"""Convert relative (negative) indexes to absolute dimension numbers."""
dims_ = dims if isinstance(dims, Iterable) else (dims,)
if len(dims_) == 0:
return dims_
return tuple(ndim + dim if dim < 0 else dim for dim in dims_ if dim is not None)
def get_non_contracting_dims(ndim, contracting_dims):
"""Return a tuple of dimensions not included in the contracting dimensions."""
contracting_dims = sanitize_dims(ndim, contracting_dims)
return tuple(dim for dim in range(ndim) if dim not in contracting_dims)
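# Illustrative examples for the dimension helpers above (shapes assumed, not from
# the original source):
#   >>> sanitize_dims(3, (-1, 0))
#   (2, 0)
#   >>> get_non_contracting_dims(3, (-1,))
#   (0, 1)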
def transpose_dims(ndim, dims_to_transpose, flatten_axis=-1):
"""Compute the new dimension numbers after transpose."""
if len(dims_to_transpose) == 0:
return dims_to_transpose
flatten_axis = ndim - flatten_axis if flatten_axis > 0 else flatten_axis
transposed_dims = (*range(flatten_axis, ndim), *range(flatten_axis))
return tuple(transposed_dims.index(dim) for dim in dims_to_transpose)
def _compatible_fp8_gemm_dtypes(lhs_dtype, rhs_dtype) -> bool:
    """Check whether an FP8 operand dtype pair is supported by cuBLAS FP8 GEMM."""
lhs, rhs, e4m3, e5m2 = map(
dtypes.canonicalize_dtype,
(
lhs_dtype,
rhs_dtype,
jnp.float8_e4m3fn,
jnp.float8_e5m2,
),
)
# FP8 GEMM supports (e4m3 x e4m3), (e4m3 x e5m2) and (e5m2 x e4m3)
if (lhs is e4m3 and rhs in (e4m3, e5m2)) or (lhs in (e4m3, e5m2) and rhs is e4m3):
return True
# Any other combination of data types is not supported
return False
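# Illustrative check: an (e4m3 x e5m2) pair is accepted, while (e5m2 x e5m2) is not.
#   >>> _compatible_fp8_gemm_dtypes(jnp.float8_e4m3fn, jnp.float8_e5m2)
#   True
#   >>> _compatible_fp8_gemm_dtypes(jnp.float8_e5m2, jnp.float8_e5m2)
#   False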
def _get_gemm_layout(
    operand_ndims: Tuple[int, int], contracting_dims: Tuple[Sequence[int], Sequence[int]]
) -> Tuple[bool, bool]:
    """Infer the (lhs_is_transposed, rhs_is_transposed) cuBLAS layout from the contracting dims."""
lhs_contracting, rhs_contracting = map(sanitize_dims, operand_ndims, contracting_dims)
lhs_is_transposed = operand_ndims[0] - 1 not in lhs_contracting
rhs_is_transposed = operand_ndims[1] - 1 in rhs_contracting
return lhs_is_transposed, rhs_is_transposed
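# Example (illustrative, 2-D operands): contracting the LHS's last axis with the
# RHS's first axis is the non-transposed "NN" layout, while the reverse is "TT".
#   >>> _get_gemm_layout((2, 2), ((1,), (0,)))
#   (False, False)
#   >>> _get_gemm_layout((2, 2), ((0,), (1,)))
#   (True, True)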
def _quantize_gemm_operands(lhs, rhs, lhs_quantizer, rhs_quantizer, contracting_dims):
    """Quantize raw GEMM operands with the given quantizers into the layout cuBLAS expects."""
lhs_q = lhs
rhs_q = rhs
if not isinstance(lhs, ScaledTensor) and lhs_quantizer is not None:
lhs_cdims = sanitize_dims(lhs.ndim, contracting_dims[0])
lhs_is_transposed = lhs.ndim - 1 not in lhs_cdims
need_lhs_colwise = lhs_is_transposed and (
lhs_quantizer.scaling_mode.is_1d_block_scaling()
or not is_fp8_gemm_with_all_layouts_supported()
or lhs_quantizer.scaling_mode.is_nvfp4_scaling
)
flatten_axis = max(lhs_cdims) + 1 if lhs_is_transposed else min(lhs_cdims)
lhs_q = lhs_quantizer.quantize(
lhs,
is_rowwise=not need_lhs_colwise,
is_colwise=need_lhs_colwise,
flatten_axis=flatten_axis,
)
if not isinstance(rhs, ScaledTensor) and rhs_quantizer is not None:
rhs_cdims = sanitize_dims(rhs.ndim, contracting_dims[1])
rhs_is_transposed = rhs.ndim - 1 in rhs_cdims
need_rhs_colwise = not rhs_is_transposed and (
rhs_quantizer.scaling_mode.is_1d_block_scaling()
or not is_fp8_gemm_with_all_layouts_supported()
or rhs_quantizer.scaling_mode.is_nvfp4_scaling
)
flatten_axis = min(rhs_cdims) if rhs_is_transposed else max(rhs_cdims) + 1
rhs_q = rhs_quantizer.quantize(
rhs,
is_rowwise=not need_rhs_colwise,
is_colwise=need_rhs_colwise,
flatten_axis=flatten_axis,
)
if isinstance(lhs_q, ScaledTensor2x):
raise TypeError(
"Expected lhs_q to not be ScaledTensor2x after quantization, but got"
f" type={type(lhs_q)}"
)
if isinstance(rhs_q, ScaledTensor2x):
raise TypeError(
"Expected rhs_q to not be ScaledTensor2x after quantization, but got"
f" type={type(rhs_q)}"
)
def has_rht_applied(q: AbstractBaseTensor) -> bool:
return isinstance(q, ScaledTensor1x) and q.has_rht_applied
if has_rht_applied(lhs_q) != has_rht_applied(rhs_q):
raise ValueError(
"With NVFP4_1D_SCALING, if one operand is quantized with RHT, the other must be"
" quantized with RHT as well. This is to ensure the RHT is applied to both and will"
" cancel out in the GEMM."
)
return lhs_q, rhs_q
def _get_nvfp4_tensor_scale_inv(amax):
    """Compute the NVFP4 per-tensor inverse scale from the tensor amax."""
DATA_DTYPE_MAX = jnp.finfo(jnp.float4_e2m1fn.dtype).max.astype(jnp.float32)
SCALE_DTYPE_MAX = jnp.finfo(jnp.float8_e4m3fn.dtype).max.astype(jnp.float32)
return amax / (DATA_DTYPE_MAX * SCALE_DTYPE_MAX)
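# NVFP4 uses two-level scaling: the per-tensor scale divides amax by the product of
# the FP4 (e2m1) data max and the FP8 (e4m3) block-scale max, i.e. 6.0 * 448.0 = 2688.0,
# so an amax of 2688.0 yields an inverse scale of 1.0 (illustrative value).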
def collective_gemm_bootstrap(
num_total_devices,
num_devices_per_process,
process_id,
tensor_parallel_size,
num_max_streams=3,
compute_stream_priority=0,
communication_stream_priority=0,
num_sm_for_communication=2,
use_ce=True,
aggregate_all_gather=False,
):
"""Initialize NCCL communicators for Collective GEMM operations.
This function sets up the distributed communication infrastructure needed for
tensor parallel collective GEMM operations. It supports two main scenarios:
1. **Multi-device per process**: TP domain = single process
- Each process manages multiple GPUs (num_devices_per_process > 1)
- TP group consists of GPUs within the same process
- Example: 2 processes × 4 GPUs each = 8 total ranks, tp_size=4
2. **Single device per process**: TP domain spans multiple processes
- Each process manages one GPU (num_devices_per_process = 1)
- TP group spans across multiple processes
- Example: 8 processes × 1 GPU each = 8 total ranks, tp_size=4
Args:
num_total_devices (int): Total number of ranks across all processes.
Must be divisible by num_devices_per_process.
num_devices_per_process (int): Number of GPUs per process.
- For multi-device: equals tp_size (e.g., 4 GPUs per process)
- For single-device: equals 1 (1 GPU per process)
process_id (int): Process identifier (0-based).
Must be in range [0, num_total_devices // num_devices_per_process).
tensor_parallel_size (int): Size of tensor parallel groups.
Must divide num_total_devices evenly.
num_max_streams (int, optional): Maximum number of CUDA streams for overlap.
Higher values enable more parallelism but use more GPU resources. Default: 3.
compute_stream_priority (int, optional): Priority for GEMM computation streams.
Lower values = higher priority. Range: 0 (highest) to 3 (lowest). Default: 0.
communication_stream_priority (int, optional): Priority for NCCL communication streams.
Lower values = higher priority. Range: 0 (highest) to 3 (lowest). Default: 0.
num_sm_for_communication (int, optional): Number of streaming multiprocessors
reserved for communication operations. Default: 2.
use_ce (bool, optional): Enable CUDA copy engines for memory transfers.
Can improve performance by offloading memory operations. Default: True.
aggregate_all_gather (bool, optional): Aggregate multiple small all-gather operations
into larger ones for better efficiency. Default: False.
    Raises:
        ValueError: If num_total_devices is not divisible by num_devices_per_process,
            or if process_id is out of the valid range.
        RuntimeError: If num_devices_per_process is not 1 (temporary: only a single device
            per process is supported for now), if NCCL initialization fails, or if the
            configuration is invalid (e.g., insufficient GPUs).
Example:
# Basic initialization (single device per process)
collective_gemm_bootstrap(
num_total_devices=8,
num_devices_per_process=1,
process_id=0,
tensor_parallel_size=4
)
# Advanced configuration with custom performance settings
collective_gemm_bootstrap(
num_total_devices=8,
num_devices_per_process=1,
process_id=0,
tensor_parallel_size=4,
num_max_streams=5, # More parallelism
compute_stream_priority=1, # Lower compute priority
communication_stream_priority=0, # Higher comm priority
num_sm_for_communication=4, # More SMs for communication
use_ce=True, # Enable copy engines
aggregate_all_gather=True # Aggregate small operations
)
Note:
This function must be called after JAX distributed initialization
and before any collective GEMM operations. Each process should call
this function with its own unique process_id.
"""
if not (num_devices_per_process == 1 and jax.local_device_count() == 1):
raise RuntimeError("Only single device per process is supported at the moment!")
if num_total_devices % num_devices_per_process != 0:
raise ValueError(
f"Invalid num_total_devices={num_total_devices},"
f" num_devices_per_process={num_devices_per_process}"
)
if not 0 <= process_id < num_total_devices:
raise ValueError(f"Invalid process_id={process_id}")
initialize_cgemm_communicator(
num_total_devices,
num_devices_per_process,
process_id,
tensor_parallel_size,
num_max_streams,
compute_stream_priority,
communication_stream_priority,
num_sm_for_communication,
use_ce,
aggregate_all_gather,
)
class CollectiveOp(Enum):
"Enum for Collective Type in Collective GEMM"
NONE = JAXX_Collective_Op.NONE
ALL_GATHER = JAXX_Collective_Op.ALL_GATHER
REDUCE_SCATTER = JAXX_Collective_Op.REDUCE_SCATTER
@property
def is_all_gather(self) -> bool:
"""Check if AllGather"""
return self == CollectiveOp.ALL_GATHER
@property
def is_reduce_scatter(self) -> bool:
"""Check if ReduceScatter"""
return self == CollectiveOp.REDUCE_SCATTER
@property
def is_none(self) -> bool:
"""Check if None"""
return self == CollectiveOp.NONE
@dataclass(frozen=True)
class CollectiveOpSet:
"""
    A pair of CollectiveOp objects providing complementary collective GEMM configurations
    for the forward and backward passes through dense layers.
"""
forward: CollectiveOp
backward: CollectiveOp
@staticmethod
def create(forward_collective_op: CollectiveOp):
"""Create a set of CollectiveOp for forward and backward passes"""
if forward_collective_op.is_all_gather:
backward_collective_op = CollectiveOp.REDUCE_SCATTER
elif forward_collective_op.is_reduce_scatter:
backward_collective_op = CollectiveOp.ALL_GATHER
else:
backward_collective_op = CollectiveOp.NONE
return CollectiveOpSet(forward=forward_collective_op, backward=backward_collective_op)
noop_collective_op_set = CollectiveOpSet.create(forward_collective_op=CollectiveOp.NONE)
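# Example (illustrative): a forward all-gather pairs with a backward reduce-scatter,
# and vice versa.
#   >>> op_set = CollectiveOpSet.create(forward_collective_op=CollectiveOp.ALL_GATHER)
#   >>> op_set.backward is CollectiveOp.REDUCE_SCATTER
#   True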
@partial(jax.jit, static_argnums=(1, 2))
def swizzled_scale(scale_inv, flatten_axis, is_colwise):
"Swizzle scale_inv via JAX transpose ops"
original_shape = scale_inv.shape
shape_2d = (math.prod(original_shape[:flatten_axis]), math.prod(original_shape[flatten_axis:]))
if is_colwise:
scale_inv = jnp.transpose(scale_inv.reshape(shape_2d))
cols, rows = shape_2d
else:
rows, cols = shape_2d
reshape = scale_inv.reshape(rows // 128, 4, 32, cols // 4, 4)
swizzled = jnp.transpose(reshape, (0, 3, 2, 1, 4))
return swizzled.reshape(original_shape)
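# Sketch of the swizzle with assumed shapes: a row-wise (256, 16) scale_inv is viewed
# as (256 // 128, 4, 32, 16 // 4, 4), permuted with transpose(0, 3, 2, 1, 4) to
# interleave the 4-wide row and column groups, then flattened back to (256, 16).
# Column-wise scales are first transposed into row-major 2-D form.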
def get_lhs_axis_boundary(lhs_cdims, is_transposed):
"""Get the axis boundary for the LHS operand."""
return max(lhs_cdims) + 1 if is_transposed else min(lhs_cdims)
def get_rhs_axis_boundary(rhs_cdims, is_transposed):
"""Get the axis boundary for the RHS operand."""
return min(rhs_cdims) if is_transposed else max(rhs_cdims) + 1
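# Example (illustrative): a 3-D non-transposed LHS contracted on its last axis has
# boundary min((2,)) == 2, while a transposed LHS contracted on axes (0, 1) has
# boundary max((0, 1)) + 1 == 2.
#   >>> get_lhs_axis_boundary((2,), is_transposed=False)
#   2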
@cache
def _get_high_precision_accumulation_from_env() -> bool:
"""Read NVTE_FP8_GEMM_HIGH_PRECISION_ACCUMULATION once per process (cached)."""
return os.getenv("NVTE_FP8_GEMM_HIGH_PRECISION_ACCUMULATION", "0") == "1"
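# Note: the result is cached, so the environment variable is read once per process
# and later changes have no effect. Illustrative usage (script name hypothetical):
#   NVTE_FP8_GEMM_HIGH_PRECISION_ACCUMULATION=1 python train_script.py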
def assert_cublas_requirements(scaling_mode, contracting_size, tensor_name):
"""Assert that the given tensor shape and layout meet the requirements for cuBLAS GEMM."""
if scaling_mode != ScalingMode.NO_SCALING:
# Requirements from https://docs.nvidia.com/cuda/cublas/#tensor-core-usage
alignment = 32 if scaling_mode.is_nvfp4_scaling else 16
if contracting_size % alignment != 0:
raise ValueError(
f"cuBLAS GEMM {tensor_name} tensor's contracting dimension must be a multiple of"
f" {alignment} when using quantized inputs. Got contracting_size={contracting_size}"
)
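# Example (illustrative sizes): an FP8 GEMM with contracting size 1024 passes
# (1024 % 16 == 0) while 1000 raises a ValueError; NVFP4 scaling tightens the
# alignment requirement to 32.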
def _reorder_tpsp_leading(tensor, original_shape):
"""Reorder tensor so the tpsp axis is leading: reshape (dp, n, tpsp, m, ...), transpose (2, 0, 1, 3, ...)."""
assert original_shape[0] % dp_or_fsdp_axis_size() == 0 or original_shape[0] == 1, (
f"Original_shape[0]={original_shape[0]} is not divisible by"
f" dp_or_fsdp_axis_size()={dp_or_fsdp_axis_size()}"
)
assert original_shape[1] % tpsp_axis_size() == 0 or original_shape[1] == 1, (
f"Original_shape[1]={original_shape[1]} is not divisible by"
f" tpsp_axis_size()={tpsp_axis_size()}"
)
reshaped = tensor.reshape(
dp_or_fsdp_axis_size(),
int(original_shape[0] / dp_or_fsdp_axis_size()),
tpsp_axis_size(),
int(original_shape[1] / tpsp_axis_size()),
*original_shape[2:],
)
reordered = reshaped.transpose(2, 0, 1, 3, *range(4, reshaped.ndim))
return reordered.reshape(original_shape)
def _reorder_dp_leading(tensor, original_shape):
"""Reorder tensor so the dp axis is leading: reshape (tpsp, dp, n, m, ...), transpose (1, 2, 0, 3, ...)."""
assert original_shape[0] % dp_or_fsdp_axis_size() == 0 or original_shape[0] == 1, (
f"Original_shape[0]={original_shape[0]} is not divisible by"
f" dp_or_fsdp_axis_size()={dp_or_fsdp_axis_size()}"
)
assert original_shape[1] % tpsp_axis_size() == 0 or original_shape[1] == 1, (
f"Original_shape[1]={original_shape[1]} is not divisible by"
f" tpsp_axis_size()={tpsp_axis_size()}"
)
reshaped = tensor.reshape(
tpsp_axis_size(),
dp_or_fsdp_axis_size(),
int(original_shape[0] / dp_or_fsdp_axis_size()),
int(original_shape[1] / tpsp_axis_size()),
*original_shape[2:],
)
reordered = reshaped.transpose(1, 2, 0, 3, *range(4, reshaped.ndim))
return reordered.reshape(original_shape)
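# Illustrative round trip with assumed axis sizes: with dp_or_fsdp_axis_size() == 2
# and tpsp_axis_size() == 2, a (4, 4) tensor is viewed as (dp=2, n=2, tpsp=2, m=2)
# blocks and permuted so the tpsp blocks lead. _reorder_dp_leading applies the
# inverse permutation, so
#   _reorder_dp_leading(_reorder_tpsp_leading(x, x.shape), x.shape) == x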
class GemmPrimitive(BasePrimitive):
"""
Primitive for cuBLAS GEMM
"""
name = "te_gemm_v2_ffi"
multiple_results = True
impl_static_args = (7, 8, 9, 10, 11, 12, 13, 14)
inner_primitive = None
outer_primitive = None
@staticmethod
def abstract(
lhs,
lhs_scale_inv,
rhs,
rhs_scale_inv,
bias,
alpha,
beta,
out_dtype,
contracting_dims,
scaling_mode,
use_split_accumulator,
transpose_batch_sequence,
sequence_dim,
is_outer,
collective_op,
):
del use_split_accumulator, transpose_batch_sequence
def _dims_are_consecutive(dims):
if len(dims) <= 1:
return True
return sorted(dims) == list(range(min(dims), max(dims) + 1))
# Sanity-check operand layouts and types
operand_ndims = (lhs.ndim, rhs.ndim)
(
lhs_contracting_dims,
rhs_contracting_dims,
) = map(sanitize_dims, operand_ndims, contracting_dims)
if not _dims_are_consecutive(lhs_contracting_dims):
raise ValueError(
"cuBLAS GEMM expected consecutive contracting dimensions for LHS operand, but got "
f"{lhs_contracting_dims}."
)
if not _dims_are_consecutive(rhs_contracting_dims):
raise ValueError(
"cuBLAS GEMM expected consecutive contracting dimensions for RHS operand, but got "
f"{rhs_contracting_dims}."
)
lhs_contracting_size, rhs_contracting_size = map(
lambda shape, dims: reduce(operator.mul, [shape[dim] for dim in dims]),
(lhs.shape, rhs.shape),
(lhs_contracting_dims, rhs_contracting_dims),
)
if lhs_contracting_size != rhs_contracting_size:
raise ValueError(
f"cuBLAS GEMM operands have incompatible contracting dimensions: {lhs.shape} @ idx"
f" {lhs_contracting_dims} X {rhs.shape} @ idx {rhs_contracting_dims}."
)
assert_cublas_requirements(scaling_mode, lhs_contracting_size, "LHS")
assert_cublas_requirements(scaling_mode, rhs_contracting_size, "RHS")
lhs_is_transposed, rhs_is_transposed = _get_gemm_layout(operand_ndims, contracting_dims)
if scaling_mode != ScalingMode.NO_SCALING:
if not (
scaling_mode.is_nvfp4_scaling or _compatible_fp8_gemm_dtypes(lhs.dtype, rhs.dtype)
):
raise ValueError(
"cuBLAS GEMM quantized operands have incompatible data types: "
f"{lhs.dtype} x {rhs.dtype}."
)
if not (lhs_scale_inv.size > 0 and rhs_scale_inv.size > 0):
raise ValueError(
"Quantized cuBLAS GEMM requires inverse scaling factors for both operands."
)
if (
scaling_mode != ScalingMode.MXFP8_1D_SCALING
and not is_fp8_gemm_with_all_layouts_supported()
):
if lhs_is_transposed or not rhs_is_transposed:
                    raise ValueError(
                        "cuBLAS FP8 GEMM on devices with compute capability < 10.0"
                        " (e.g. Hopper) requires a non-transposed LHS and a transposed"
                        " RHS operand (`contracting_dims=((-1,), (-1,))`)."
                    )
else:
if lhs.dtype != rhs.dtype:
raise ValueError(
"For TE cuBLAS GEMM for non-quantized inputs, the operand dtypes must be equal."
f" LHS dtype != RHS dtype, lhs.dtype={lhs.dtype}, rhs.dtype={rhs.dtype}"
)
# Determine output shape and dtype
if not dtypes.canonicalize_dtype(out_dtype).itemsize > 1:
raise ValueError("cuBLAS GEMM custom op does not support 8-bit quantized output types.")
lhs_non_contracting_shape, rhs_non_contracting_shape = map(
lambda shape, dims: [shape[dim] for dim in range(len(shape)) if dim not in dims],
(lhs.shape, rhs.shape),
(lhs_contracting_dims, rhs_contracting_dims),
)
out_shape = (*lhs_non_contracting_shape, *rhs_non_contracting_shape)
output = jax.core.ShapedArray(shape=out_shape, dtype=out_dtype)
# Adjust output shape for comm+GEMM overlap
if not collective_op.is_none and not is_outer: # Inner abstract
if sequence_dim != 1:
raise ValueError(f"Invalid sequence_dim. Got sequence_dim={sequence_dim}")
            overlap_out_shape = list(out_shape)
if collective_op.is_all_gather:
overlap_out_shape[1] *= tpsp_axis_size()
else: # RS
overlap_out_shape[sequence_dim] = (
overlap_out_shape[sequence_dim] // tpsp_axis_size()
)
            if out_dtype != jnp.bfloat16:
                raise ValueError(
                    "Collective GEMM currently supports only bfloat16 outputs. Got"
                    f" out_dtype={out_dtype}"
                )
output = jax.core.ShapedArray(shape=overlap_out_shape, dtype=out_dtype)
# Validate bias when present (bias.size > 0 means fuse bias)
if bias.size > 0:
if bias.shape != tuple(rhs_non_contracting_shape):
raise ValueError(
"cuBLAS GEMM bias tensor has incorrect shape, "
f"expected ({tuple(rhs_non_contracting_shape)}, ) but found {bias.shape}."
)
if bias.dtype != out_dtype:
raise ValueError(
"cuBLAS GEMM bias tensor has incorrect data type, "
f"expected {out_dtype} but found {bias.dtype}."
)
if alpha.size != 1 or alpha.dtype != jnp.float32:
raise ValueError(
f"Expected alpha to be a single float32 scalar, but got alpha.size={alpha.size},"
f" alpha.dtype={alpha.dtype}"
)
if beta.size != 1 or beta.dtype != jnp.float32:
raise ValueError(
f"Expected beta to be a single float32 scalar, but got beta.size={beta.size},"
f" beta.dtype={beta.dtype}"
)
# Declare cuBLAS workspace
workspace_size = get_cublas_workspace_size_bytes()
        # NVFP4 swizzling happens in an nvte kernel instead of via JAX transposes
if scaling_mode.is_nvfp4_scaling:
workspace_size += lhs_scale_inv.size + rhs_scale_inv.size
if not collective_op.is_none:
workspace_size *= get_cgemm_num_max_streams()
        # The cuBLAS workspace pointer must be 256-byte aligned, but JAX buffers are not
        # necessarily 256-byte aligned, so we add padding to guarantee alignment.
workspace_size += 256
workspace = jax.core.ShapedArray(shape=(workspace_size,), dtype=jnp.uint8)
return output, workspace
@staticmethod
def outer_abstract(*args, **kwargs):
output, _ = GemmPrimitive.abstract(*args, **kwargs)
return (output,)
@staticmethod
def lowering(
ctx,
lhs,
lhs_scale_inv,
rhs,
rhs_scale_inv,
bias,
alpha,
beta,
out_dtype,
contracting_dims,
scaling_mode,
use_split_accumulator,
transpose_batch_sequence,
sequence_dim,
is_outer,
collective_op,
):
del out_dtype, transpose_batch_sequence, sequence_dim, is_outer
lhs_aval, _, rhs_aval, *_ = ctx.avals_in
lhs_cdims, rhs_cdims = map(sanitize_dims, (lhs_aval.ndim, rhs_aval.ndim), contracting_dims)
lhs_transposed, rhs_transposed = _get_gemm_layout(
(lhs_aval.ndim, rhs_aval.ndim), (lhs_cdims, rhs_cdims)
)
args = (lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, alpha, beta)
kwargs = {
"scaling_mode": int(scaling_mode.value),
"collective_op": int(collective_op.value),
"lhs_axis_boundary": get_lhs_axis_boundary(lhs_cdims, lhs_transposed),
"rhs_axis_boundary": get_rhs_axis_boundary(rhs_cdims, rhs_transposed),
"lhs_transposed": lhs_transposed,
"rhs_transposed": rhs_transposed,
"use_split_accumulator": use_split_accumulator,
}
return jax.ffi.ffi_lowering(GemmPrimitive.name)(ctx, *args, config=kwargs)
@staticmethod
def impl(
lhs,
lhs_scale_inv,
rhs,
rhs_scale_inv,
bias,
alpha,
beta,
out_dtype,
contracting_dims,
scaling_mode,
use_split_accumulator,
transpose_batch_sequence,
sequence_dim,
is_outer,
collective_op,
):
if scaling_mode.is_1d_block_scaling():
lhs_cdims, rhs_cdims = map(sanitize_dims, (lhs.ndim, rhs.ndim), contracting_dims)
lhs_transposed, rhs_transposed = _get_gemm_layout(
(lhs.ndim, rhs.ndim), (lhs_cdims, rhs_cdims)
)
lhs_flatten_axis = max(lhs_cdims) + 1 if lhs_transposed else min(lhs_cdims)
rhs_flatten_axis = min(rhs_cdims) if rhs_transposed else max(rhs_cdims) + 1
if not collective_op.is_none and not is_outer:
# MXFP8 + Collective AG/RS: both sides of flatten_axis must be multiples of 128.
# No padding is needed in this case
lhs_first, lhs_last = math.prod(lhs.shape[:lhs_flatten_axis]), math.prod(
lhs.shape[lhs_flatten_axis:]
)
assert lhs_first % 128 == 0 and lhs_last % 128 == 0, (
"MXFP8 + Collective AG/RS requires LHS dimensions before and after the flatten"
f" axis to be multiples of 128. Got lhs.shape={lhs.shape},"
f" lhs_flatten_axis={lhs_flatten_axis}"
)
rhs_first, rhs_last = math.prod(rhs.shape[:rhs_flatten_axis]), math.prod(
rhs.shape[rhs_flatten_axis:]
)
                assert rhs_first % 128 == 0 and rhs_last % 128 == 0, (
                    "MXFP8 + Collective AG/RS requires RHS dimensions before and after the flatten"
                    f" axis to be multiples of 128. Got rhs.shape={rhs.shape},"
                    f" rhs_flatten_axis={rhs_flatten_axis}"
                )
                # The scale must be evenly divisible along the sequence dim for reordering
                assert lhs_scale_inv.shape[sequence_dim] % tpsp_axis_size() == 0, (
                    "MXFP8 + Collective AG/RS requires the LHS scale inv sequence dimension to be"
                    f" a multiple of tpsp_axis_size. Got lhs_scale_inv.shape={lhs_scale_inv.shape},"
                    f" tpsp_axis_size={tpsp_axis_size()}, sequence_dim={sequence_dim}"
                )
else:
lhs_scale_inv = apply_padding_to_scale_inv(
lhs_scale_inv,
scaling_mode,
lhs.shape,
lhs_transposed,
lhs_flatten_axis,
)
rhs_scale_inv = apply_padding_to_scale_inv(
rhs_scale_inv, scaling_mode, rhs.shape, not rhs_transposed, rhs_flatten_axis
)
                # Only perform the JAX-based swizzle for MXFP8; the NVFP4 swizzle goes
                # through an nvte kernel
if scaling_mode.is_mxfp8_scaling:
lhs_scale_inv = swizzled_scale(lhs_scale_inv, lhs_flatten_axis, lhs_transposed)
rhs_scale_inv = swizzled_scale(rhs_scale_inv, rhs_flatten_axis, not rhs_transposed)
        # Determine whether we need to reorder the tensor so that the inputs/outputs
        # are laid out correctly for the collective operation
need_reorder = not transpose_batch_sequence and not is_outer and not collective_op.is_none
# Alter lhs blocks so that CGEMM RS outputs correctly
if need_reorder and collective_op.is_reduce_scatter and lhs.shape[0] != 1:
assert sequence_dim == 1, f"Invalid sequence_dim. Got sequence_dim={sequence_dim}"
lhs = _reorder_tpsp_leading(lhs, lhs.shape)
if (
need_reorder
and (collective_op.is_reduce_scatter or collective_op.is_all_gather)
and lhs_scale_inv.shape[0] != 1
and scaling_mode.is_1d_block_scaling()
):
assert sequence_dim == 1, f"Invalid sequence_dim. Got sequence_dim={sequence_dim}"
lhs_scale_inv = _reorder_tpsp_leading(lhs_scale_inv, lhs_scale_inv.shape)
(output, _) = GemmPrimitive.inner_primitive.bind(
lhs,
lhs_scale_inv,
rhs,
rhs_scale_inv,
bias,
alpha,
beta,
out_dtype=out_dtype,
contracting_dims=contracting_dims,
scaling_mode=scaling_mode,
use_split_accumulator=use_split_accumulator,
transpose_batch_sequence=transpose_batch_sequence,
sequence_dim=sequence_dim,
is_outer=is_outer,
collective_op=collective_op,
)
# Alter output blocks for CGEMM AG
if need_reorder and collective_op.is_all_gather and output.shape[0] != 1:
assert sequence_dim == 1, f"Invalid sequence_dim. Got sequence_dim={sequence_dim}"
output = _reorder_dp_leading(output, output.shape)
return (output,)
@staticmethod
def outer_impl(
lhs,
lhs_scale_inv,
rhs,
rhs_scale_inv,
bias,
alpha,
beta,
out_dtype,
contracting_dims,
scaling_mode,
use_split_accumulator,
transpose_batch_sequence,
sequence_dim,
is_outer,
collective_op,
):
return GemmPrimitive.impl(
lhs,
lhs_scale_inv,
rhs,
rhs_scale_inv,
bias,
alpha,
beta,
out_dtype,
contracting_dims,
scaling_mode,
use_split_accumulator,
transpose_batch_sequence,
sequence_dim,
is_outer,
collective_op,
)
@staticmethod
def batcher(
batched_args,
batch_dims,
out_dtype,
contracting_dims,
scaling_mode,
use_split_accumulator,
collective_op,
transpose_batch_sequence,
sequence_dim,
is_outer,
):
if GemmPrimitive.outer_primitive is None:
raise RuntimeError("GemmPrimitive.outer_primitive has not been registered")
lhs_bdims, _, rhs_bdims, *_ = batch_dims
# Batched GEMM is not supported
if not (lhs_bdims is None and rhs_bdims is None):
raise RuntimeError(
f"Batching is not supported, got lhs_bdims={lhs_bdims}, rhs_bdims={rhs_bdims}"
)
out_bdims = (None,)
return (
GemmPrimitive.outer_primitive.bind(
*batched_args,
out_dtype=out_dtype,
contracting_dims=contracting_dims,
scaling_mode=scaling_mode,
use_split_accumulator=use_split_accumulator,
collective_op=collective_op,
transpose_batch_sequence=transpose_batch_sequence,
sequence_dim=sequence_dim,
is_outer=is_outer,
),
(out_bdims,),
)
@staticmethod
def _parse_operand_output_specs(
arg_infos,
contracting_dims,
transpose_batch_sequence,
collective_op,
scaling_mode,
):
lhs_specs, _, rhs_specs, *_ = map(get_padded_spec, arg_infos)
gsr = global_mesh_resource()
        # Warn if tensor sequence parallelism is expressed via tp_resource instead of tpsp_resource
if gsr.tp_resource is not None:
if gsr.tp_resource in lhs_specs:
                warnings.warn(
                    f"Tensor sequence parallelism detected: tp_resource='{gsr.tp_resource}'"
                    f" appears in lhs_specs: {lhs_specs}. Please set MeshResource.tpsp_resource"
                    " for tensor sequence parallelism to avoid potential issues."
                )
lhs_ndim, rhs_ndim = map(len, (lhs_specs, rhs_specs))
lhs_cdims, rhs_cdims = map(sanitize_dims, (lhs_ndim, rhs_ndim), contracting_dims)
lhs_non_cdims, rhs_non_cdims = map(
lambda ndim, cdims: tuple(i for i in range(ndim) if i not in cdims),
(lhs_ndim, rhs_ndim),
(lhs_cdims, rhs_cdims),
)
lhs_non_cspecs, lhs_cspecs, rhs_non_cspecs, rhs_cspecs = map(
lambda specs, dims: tuple(specs[i] for i in dims),
(lhs_specs, lhs_specs, rhs_specs, rhs_specs),
(lhs_non_cdims, lhs_cdims, rhs_non_cdims, rhs_cdims),
)
reduce_spec = None
for l in lhs_cspecs:
for r in rhs_cspecs:
if l is not None and l == r:
if reduce_spec is not None:
raise RuntimeError("Multiple reduce dimension is detected!")
reduce_spec = l
sequence_dim = None
        # Find the sequence dimension in lhs_specs if tensor sequence parallelism is enabled.
        # CollectiveGemm AG is only applied to x or dY, which are always the LHS and carry
        # the sequence dim.
if collective_op.is_all_gather:
try:
tpsp_idx = lhs_specs.index(gsr.tpsp_resource)
except ValueError as exc:
raise ValueError(
f"tpsp_resource '{gsr.tpsp_resource}' is not found in lhs_specs: {lhs_specs}."
" Please check your sharding configuration."
) from exc
sequence_dim = tpsp_idx
if not (sequence_dim == 1) ^ transpose_batch_sequence:
raise ValueError(
"CollectiveGEMM supports only (sequence_dim=1 and"
" transpose_batch_sequence=False) or (sequence_dim=0 and"
f" transpose_batch_sequence=True). Received: sequence_dim={sequence_dim},"
f" transpose_batch_sequence={transpose_batch_sequence}."
)
elif collective_op.is_reduce_scatter:
if reduce_spec != gsr.tpsp_resource:
raise ValueError(
"Only CollectiveGemm RS with the Reduction over the TPSP axis is supported! Got"
f" reduce_spec={reduce_spec}, tpsp_resource={gsr.tpsp_resource}"
)
sequence_dim = int(not transpose_batch_sequence)
if reduce_spec is not None:
# Other non-reduce cdims (if exists) need to be unsharded
lhs_cspecs = tuple(s if s == reduce_spec else None for s in lhs_cspecs)
# Only do AG Sequence dim if not Overlap
if collective_op.is_all_gather:
rhs_cspecs = tuple(
s if s in (reduce_spec, gsr.tpsp_resource) else None for s in rhs_cspecs
)
else:
rhs_cspecs = tuple(s if s == reduce_spec else None for s in rhs_cspecs)
            # Non-contracting dims of RHS always need to be gathered, e.g. for TP + activation_hidden.
            # No batch-dim check is needed as `rhs_non_cspecs` never contains the batch dim.
            # In `rhs_specs`, the batch dim appears only in the Wgrad GEMM under `rhs_cspecs`.
rhs_non_cspecs = tuple(
None if spec in lhs_non_cspecs else spec for spec in rhs_non_cspecs
)
else:
# Otherwise, require contracting dims of both operands to be unsharded
lhs_cspecs = (None,) * len(lhs_cspecs)
rhs_cspecs = (None,) * len(rhs_cspecs)
            # Non-contracting dims of RHS always need to be gathered along the FSDP axis
rhs_non_cspecs = tuple(
(
None
if spec is not None
and (
spec == gsr.fsdp_resource
or (isinstance(spec, tuple) and gsr.fsdp_resource in spec)
)
else spec
)
for spec in rhs_non_cspecs
)
# Only do AG Sequence dim if not Overlap
if not collective_op.is_all_gather:
                # Non-contracting dims of LHS are gathered along the SP axis.
                # Minor note: this causes MaxText TP (= Megatron TP + activation_hidden sharding)
                # to gather x for dW1 = x^T * dY1, which is unexpected. This is a known issue
                # with no solution found yet.
lhs_non_cspecs = tuple(
None if spec in rhs_non_cspecs else spec for spec in lhs_non_cspecs
)
out_specs = lhs_non_cspecs + rhs_non_cspecs
# Only do AG Sequence dim if not Overlap RS
if collective_op.is_all_gather:
if sequence_dim > len(lhs_non_cspecs):
raise ValueError(
f"Sequence dim {sequence_dim} is out of bounds for lhs_non_cspecs:"
f" {lhs_non_cspecs}"
)
out_specs = out_specs[:sequence_dim] + (None,) + out_specs[sequence_dim + 1 :]
elif collective_op.is_reduce_scatter:
if sequence_dim > len(lhs_non_cspecs):
raise ValueError(
f"Sequence dim {sequence_dim} is out of bounds for lhs_non_cspecs:"
f" {lhs_non_cspecs}"
)
out_specs = (
out_specs[:sequence_dim] + (gsr.tpsp_resource,) + out_specs[sequence_dim + 1 :]
)
# specs = merge(cspecs, non_cspecs)
lhs_specs, rhs_specs = map(
lambda cdims, cspecs, non_cspecs: (
cspecs + non_cspecs if cdims[0] == 0 else non_cspecs + cspecs
),
(lhs_cdims, rhs_cdims),
(lhs_cspecs, rhs_cspecs),
(lhs_non_cspecs, rhs_non_cspecs),
)
# Bias sharding is based on GEMM output before any scatter
bias_specs = rhs_non_cspecs if arg_infos[4].size > 0 else (None,) # bias is operand index 4
# Scale shardings are based on the scaling_mode and collective_op
lhs_scale_specs = rhs_scale_specs = (None,)
if scaling_mode.is_1d_block_scaling():
rhs_scale_specs = rhs_specs
            # Set the seq spec to None to trigger an AG of the scales, as TE/common CGEMM
            # does not handle scale collection yet