-
Notifications
You must be signed in to change notification settings - Fork 702
Expand file tree
/
Copy pathtest_fusible_ops.py
More file actions
4465 lines (4044 loc) · 163 KB
/
test_fusible_ops.py
File metadata and controls
4465 lines (4044 loc) · 163 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
from __future__ import annotations
from collections.abc import Iterable
import functools
import io
import math
import random
from typing import Optional
import pytest
import torch
import transformer_engine
import transformer_engine.common.recipe
import transformer_engine.pytorch as te
import transformer_engine.pytorch.ops as te_ops
from transformer_engine.pytorch.ops.fused import (
BackwardActivationBias,
BackwardAddRMSNorm,
BackwardLinearAdd,
BackwardLinearScale,
ForwardLinearBiasActivation,
ForwardLinearBiasAdd,
ForwardLinearScaleAdd,
)
from transformer_engine.pytorch import (
QuantizedTensor,
Float8CurrentScalingQuantizer,
Float8Quantizer,
MXFP8Quantizer,
NVFP4Quantizer,
is_bf16_available,
)
from transformer_engine.pytorch.tensor.grouped_tensor import GroupedTensor
from transformer_engine.pytorch.cpp_extensions.gemm import general_grouped_gemm_for_grouped_tensor
import transformer_engine_torch as tex
# Import utility functions
from utils import (
assert_close,
assert_close_grads,
dtype_tols,
make_recipe,
quantization_tols,
reset_rng_states,
)
# Query which quantization schemes the current build/hardware supports,
# together with a human-readable reason when a scheme is unavailable
fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True)
mxfp8_available, reason_for_no_mxfp8 = te.is_mxfp8_available(return_reason=True)
nvfp4_available, reason_for_no_nvfp4 = te.is_nvfp4_available(return_reason=True)

# Data types exercised by the tests (bf16 requires sm_80 or higher)
_dtypes: list[torch.dtype] = [
    torch.float32,
    torch.float16,
    *([torch.bfloat16] if is_bf16_available() else []),
]

# Devices exercised by the tests
_devices: list[torch.device] = [torch.device(name) for name in ("cpu", "cuda")]

# Quantization recipes exercised by the tests; ``None`` means unquantized
_quantization_list: list[Optional[str]] = [
    None,
    *(("fp8_delayed_scaling", "fp8_current_scaling") if fp8_available else ()),
    *(("mxfp8",) if mxfp8_available else ()),
    *(("nvfp4",) if nvfp4_available else ()),
]
def maybe_skip_quantization(
    quantization: Optional[str],
    *,
    dims: Optional[Iterable[int] | int] = None,
    device: Optional[torch.device | str] = None,
    dtype: Optional[torch.dtype] = None,
) -> None:
    """Skip test case if a quantization scheme is not supported

    Parameters
    ----------
    quantization: str, optional
        Name of the quantization scheme (``"fp8"``,
        ``"fp8_delayed_scaling"``, ``"fp8_current_scaling"``,
        ``"mxfp8"``, ``"nvfp4"``). If ``None``, nothing is checked.
    dims: iterable of int or int, optional
        Tensor dims to validate for quantized GEMMs. All leading dims
        are multiplied together and checked alongside the last dim.
        Any iterable is accepted (it is materialized as a tuple).
    device: torch.device or str, optional
        Device the test will run on. Quantization requires CUDA.
    dtype: torch.dtype, optional
        High-precision data type used by the test.

    Calls ``pytest.skip`` (which raises) for unsupported
    configurations; otherwise returns ``None``.

    """

    # Don't skip if there is no quantization
    if quantization is None:
        return

    fp8_names = ("fp8", "fp8_delayed_scaling", "fp8_current_scaling")

    # Check if quantization scheme is supported on device
    if device is not None and torch.device(device).type != "cuda":
        pytest.skip("Quantization is only supported on CUDA devices")
    if quantization in fp8_names and not fp8_available:
        pytest.skip(reason_for_no_fp8)
    if quantization == "mxfp8" and not mxfp8_available:
        pytest.skip(reason_for_no_mxfp8)
    if quantization == "nvfp4" and not nvfp4_available:
        pytest.skip(reason_for_no_nvfp4)

    # Check dims
    if dims is not None:
        # Materialize as a tuple: the type hint allows any iterable
        # (e.g. a generator), but the checks below rely on slicing
        dims = tuple(dims) if isinstance(dims, Iterable) else (dims,)
        if quantization in fp8_names:
            if math.prod(dims[:-1]) % 16 != 0 or dims[-1] % 16 != 0:
                pytest.skip("FP8 GEMMs require dims that are divisible by 16")
        elif quantization == "mxfp8":
            if math.prod(dims[:-1]) % 32 != 0 or dims[-1] % 32 != 0:
                pytest.skip("MXFP8 GEMMs require dims that are divisible by 32")
        elif quantization == "nvfp4":
            if math.prod(dims[:-1]) % 16 != 0 or dims[-1] % 16 != 0:
                pytest.skip("NVFP4 GEMMs require dims that are divisible by 16")

    # Check dtype
    if dtype is not None:
        if quantization == "nvfp4" and dtype != torch.bfloat16:
            pytest.skip("NVFP4 quantization is only supported with BF16 data")
@torch.no_grad()
def make_reference_and_test_tensors(
    shape: int | Iterable[int],
    *,
    min: float = 0.0,
    max: float = 1.0,
    quantization: Optional[str] = None,
    ref_dtype: torch.dtype = torch.float64,
    ref_device: torch.device = "cpu",
    test_dtype: torch.dtype = torch.float32,
    test_device: torch.device = "cuda",
    test_is_quantized: bool = False,
    requires_grad: bool = True,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Create a (reference, test) pair of tensors holding identical values

    The reference tensor is meant for high-precision plain PyTorch
    computations, while the test tensor is meant for Transformer Engine
    operations. Values are drawn uniformly from ``[min, max)``.

    When ``quantization`` names a scheme, the test tensor is passed
    through the matching quantizer and the reference is overwritten with
    the resulting values, so both tensors hold exactly representable
    numbers. The quantized representation is kept only if
    ``test_is_quantized`` is set; otherwise it is dequantized.

    Raises
    ------
    ValueError
        If ``test_is_quantized`` is requested without a scheme, or the
        scheme name is not recognized.

    """
    # High-precision random values for the reference
    ref = torch.empty(shape, dtype=ref_dtype, device=ref_device).uniform_(min, max)

    # The test tensor starts out as a cast of the reference
    test = ref.to(device=test_device, dtype=test_dtype)

    if quantization is None:
        if test_is_quantized:
            raise ValueError("Quantization scheme not provided")
        # ``Tensor.to`` may return the same storage when nothing changes;
        # force a copy so the two tensors stay independent
        if test.data_ptr() == ref.data_ptr():
            test = test.clone()
    else:
        if quantization in ("fp8", "fp8_delayed_scaling"):
            quantizer = Float8Quantizer(
                scale=torch.ones(1, dtype=torch.float32, device=test_device).squeeze(),
                amax=torch.zeros(1, dtype=torch.float32, device=test_device),
                fp8_dtype=tex.DType.kFloat8E4M3,
            )
        elif quantization == "fp8_current_scaling":
            quantizer = Float8CurrentScalingQuantizer(
                fp8_dtype=tex.DType.kFloat8E4M3,
                device=test_device,
            )
        elif quantization == "mxfp8":
            quantizer = MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3)
        elif quantization == "nvfp4":
            quantizer = NVFP4Quantizer(
                with_rht=False,
                with_post_rht_amax=False,
                with_2d_quantization=False,
                stochastic_rounding=False,
                with_random_sign_mask=False,
            )
        else:
            raise ValueError(f"Unsupported quantization scheme ({quantization})")
        test = quantizer(test)

    # Keep the quantized representation only when the caller asked for it
    if not test_is_quantized and isinstance(test, QuantizedTensor):
        test = test.dequantize()

    # Overwrite the reference with the (possibly rounded) test values so
    # that both tensors agree exactly
    ref.copy_(test.to(dtype=ref.dtype))
    for tensor in (ref, test):
        tensor.requires_grad_(requires_grad)
    return ref, test
class TestSequentialContainer:
    """Tests for sequential container"""

    def test_modules(self) -> None:
        """Check that list of modules can be manipulated as expected

        Each list-like operation is applied both to a plain Python list
        and to a ``te_ops.Sequential`` container, and the two are then
        compared element by element.
        """

        def check(container, reference) -> None:
            # Pairwise identity comparison (zip stops at the shorter one)
            for got, want in zip(container, reference):
                assert got is want

        # Mix of TE ops and plain PyTorch modules
        modules = [
            te_ops.Identity(),
            te_ops.Identity(),
            torch.nn.Identity(),
            te_ops.Identity(),
        ]
        model = te_ops.Sequential(*modules)

        # __len__ and __iter__
        assert len(model) == len(modules)
        check(model, modules)

        # __getitem__ with positive and negative integer indices
        count = len(modules)
        for pos, expected in enumerate(modules):
            assert model[pos] is expected
            assert model[pos - count] is expected

        # __getitem__ with a slice yields a new Sequential
        sub_model = model[1:-1]
        assert isinstance(sub_model, te_ops.Sequential)
        check(sub_model, modules[1:-1])

        # __setitem__
        replacement = torch.nn.Identity()
        modules[1] = replacement
        model[1] = replacement
        check(model, modules)

        # __delitem__
        del modules[1]
        del model[1]
        check(model, modules)

        # append
        extra = torch.nn.Identity()
        modules.append(extra)
        model.append(extra)
        check(model, modules)

        # extend
        extras = [te_ops.Identity(), te_ops.Identity()]
        modules.extend(extras)
        model.extend(extras)
        check(model, modules)

        # insert
        extra = te_ops.Identity()
        modules.insert(2, extra)
        model.insert(2, extra)
        check(model, modules)

        # pop
        assert model.pop(2) is modules.pop(2)
        check(model, modules)

        # __add__ returns a new container and leaves the operands untouched
        extras = [torch.nn.Identity(), te_ops.Identity()]
        combined_modules = modules + extras
        combined_model = model + te_ops.Sequential(*extras)
        check(model, modules)
        check(combined_model, combined_modules)

        # __iadd__
        extras = [te_ops.Identity(), torch.nn.Identity()]
        modules += extras
        model += te_ops.Sequential(*extras)
        check(model, modules)

    def test_module_groups(self) -> None:
        """Check that modules are grouped together correctly"""
        te_op = te_ops.Identity
        torch_op = torch.nn.Identity
        # NOTE(review): the expected count of 6 is consistent with runs of
        # consecutive TE ops sharing a group and each torch.nn module
        # forming its own group -- confirm against Sequential's grouping
        model = te_ops.Sequential(
            te_op(),
            te_op(),
            torch_op(),
            torch_op(),
            te_op(),
            torch_op(),
            te_op(),
            te_op(),
            te_op(),
        )
        model(torch.zeros(1))
        assert len(model._module_groups) == 6

    def test_extra_tensors(self, size: int = 16) -> None:
        """Check that extra inputs are distributed properly between module groups
        and that extra outputs are properly collected"""

        # A single bias op, reused at two points of the pipeline
        bias = te_ops.Bias(size=size, device="cpu")
        with torch.no_grad():
            bias.bias.copy_(torch.rand((size,)))

        def make_out():
            return te_ops.MakeExtraOutput(in_place=True)

        def add_in():
            return te_ops.AddExtraInput(in_place=True)

        model = te_ops.Sequential(  # | Inputs   | Outputs
            torch.nn.Identity(),  #     | x1       | x1
            make_out(),  #              | x1       | x1 [x1]
            bias,  #                    | x1       | h1 (= x1 + b)
            make_out(),  #              | h1       | h1 [h1]
            add_in(),  #                | h1 [x2]  | x2 (= x2 + h1)
            make_out(),  #              | x2       | x2 [x2]
            torch.nn.Identity(),  #     | x2       | x2
            bias,  #                    | x2       | h2 (= x2 + b)
            add_in(),  #                | h2 [x3]  | x3 (= x3 + h2)
            make_out(),  #              | x3       | x3 [x3]
            add_in(),  #                | x3 [x4]  | x4 (= x4 + x3)
            torch.nn.Identity(),  #     | x4       | x4
            te_ops.Identity(),  #       | x4       | x4
            make_out(),  #              | x4       | x4 [x4]
            te_ops.Identity(),  #       | x4       | x4
        )

        # Inputs, plus pristine copies since the ops work in place
        inputs = [torch.rand((size,)) for _ in range(4)]
        x1, x2, x3, x4 = inputs
        x1_orig, x2_orig, x3_orig, x4_orig = (x.clone() for x in inputs)

        ys = model(x1, x2, x3, x4)

        # Expected outputs: (x4, x1, h1, x2, x3, x4). Everything except h1
        # must alias one of the input buffers.
        assert len(ys) == 6
        input_ptrs = [x.data_ptr() for x in inputs]
        aliases = (x4, x1, None, x2, x3, x4)
        for y, source in zip(ys, aliases):
            if source is None:
                assert y.data_ptr() not in input_ptrs
            else:
                assert y.data_ptr() == source.data_ptr()

        # Values: x2, x3, x4 were updated in place by AddExtraInput
        b = bias.bias
        h1 = ys[2]
        check = torch.testing.assert_close
        check(x1, x1_orig)
        check(h1, x1_orig + b)
        check(x2, x2_orig + h1)
        check(x3, x3_orig + x2 + b)
        check(x4, x4_orig + x3)
class TestFuser:
    """Tests for operation fusion infrastructure"""

    @classmethod
    def setup_class(cls) -> None:
        """Reset RNG states before the tests in this class run"""
        reset_rng_states()

    @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
    def test_fp8_scale_update(
        self,
        size: int = 16,
        dtype: torch.dtype = torch.float32,
        device: torch.device = "cuda",
    ):
        """Test FP8 scaling factors with delayed scaling recipe

        A ``BasicLinear`` op is run for three training steps with
        constant-valued weights, inputs, and output grads. After every
        step, the outputs, input grads, and quantizer scales are compared
        with closed-form expected values.
        """

        # FP8 recipe
        margin = 2
        fp8_format = transformer_engine.common.recipe.Format.HYBRID
        recipe = transformer_engine.common.recipe.DelayedScaling(
            margin=margin,
            fp8_format=fp8_format,
            amax_history_len=8,
            amax_compute_algo="max",
        )

        # Construct model
        with te.quantized_model_init(recipe=recipe):
            model = te_ops.basic.BasicLinear(
                size,
                size,
                device=device,
                dtype=dtype,
            )

        # Training steps. w_vals has one extra entry because the weight is
        # refilled with the next value at the end of every step.
        w_vals = [2, 5, 3, 11]
        x_vals = [7, 3, 5]
        dy_vals = [1, 2, 1]
        with torch.no_grad():
            model.weight.fill_(w_vals[0])
        for step in range(3):

            # Data tensors
            x = torch.full(
                (size, size),
                x_vals[step],
                dtype=dtype,
                device=device,
                requires_grad=True,
            )
            dy = torch.full(
                (size, size),
                dy_vals[step],
                dtype=dtype,
                device=device,
            )

            # Training step
            with te.autocast(recipe=recipe):
                y = model(x)
            y.backward(dy)
            with torch.no_grad():
                model.weight.fill_(w_vals[step + 1])

            # Check that output tensors match expected. All tensors are
            # constant, so every entry is a sum of ``size`` equal products.
            y_val_ref = w_vals[step] * x_vals[step] * size
            dx_val_ref = w_vals[step] * dy_vals[step] * size
            torch.testing.assert_close(
                y,
                torch.full_like(y, y_val_ref),
                **quantization_tols("fp8_delayed_scaling"),
            )
            torch.testing.assert_close(
                x.grad,
                torch.full_like(x.grad, dx_val_ref),
                **quantization_tols("fp8_delayed_scaling"),
            )

            # Check that scaling factors match expected. With
            # amax_compute_algo="max", the expected amax is the max of all
            # values seen so far; scale = fp8_max / amax / 2**margin.
            w_amax_ref = max(w_vals[: step + 1])
            x_amax_ref = max(x_vals[: step + 1])
            dy_amax_ref = max(dy_vals[: step + 1])
            w_scale_ref = (fp8_format.value.max_fwd / w_amax_ref) / (2**margin)
            x_scale_ref = (fp8_format.value.max_fwd / x_amax_ref) / (2**margin)
            dy_scale_ref = (fp8_format.value.max_bwd / dy_amax_ref) / (2**margin)
            w_scale = model.get_quantizer("forward", 1).scale
            x_scale = model.get_quantizer("forward", 0).scale
            dy_scale = model.get_quantizer("backward", 0).scale
            torch.testing.assert_close(w_scale, torch.full_like(w_scale, w_scale_ref))
            torch.testing.assert_close(x_scale, torch.full_like(x_scale, x_scale_ref))
            torch.testing.assert_close(dy_scale, torch.full_like(dy_scale, dy_scale_ref))

    @pytest.mark.parametrize("init_dtype", _dtypes)
    @pytest.mark.parametrize("final_dtype", _dtypes)
    @pytest.mark.parametrize("quantization", _quantization_list)
    def test_dtype_cast(
        self,
        *,
        size: int = 32,
        init_dtype: torch.dtype,
        final_dtype: torch.dtype,
        device: torch.device = "cuda",
        quantization: Optional[str],
    ) -> None:
        """Check dtype cast functions

        A ``Linear`` op is created in ``init_dtype`` and then cast with
        ``.float()`` / ``.half()`` / ``.bfloat16()``. The weight keeps its
        values (and its quantized-ness), and forward/backward run with
        outputs and weight grads in ``final_dtype`` while the input grad
        stays in the input's dtype.
        """

        # Skip invalid configurations
        in_shape = (size, size)
        with_quantization = quantization is not None
        maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=init_dtype)
        maybe_skip_quantization(quantization, dtype=final_dtype)

        # Random data, in the lowest precision among the two dtypes so the
        # values are representable in both
        dtype = torch.float32
        if torch.float16 in (init_dtype, final_dtype):
            dtype = torch.float16
        if torch.bfloat16 in (init_dtype, final_dtype):
            dtype = torch.bfloat16
        w_ref, w_test = make_reference_and_test_tensors(
            (size, size),
            quantization=quantization,
            test_dtype=dtype,
            test_device=device,
        )

        # Construct operation
        with te.quantized_model_init(enabled=with_quantization, recipe=make_recipe(quantization)):
            op = te_ops.Linear(size, size, bias=False, device=device, dtype=init_dtype)
        with torch.no_grad():
            op.weight.copy_(w_test)
            del w_test

        # Cast operation dtype
        if final_dtype == torch.float32:
            op.float()
        elif final_dtype == torch.float16:
            op.half()
        elif final_dtype == torch.bfloat16:
            op.bfloat16()

        # Check weights
        assert isinstance(op.weight, QuantizedTensor) == with_quantization
        assert op.weight.dtype == final_dtype
        w_test = op.weight.to(dtype=torch.float64, device="cpu")
        torch.testing.assert_close(w_test, w_ref, **dtype_tols(dtype))

        # Check forward and backward pass
        x = torch.zeros(
            in_shape,
            dtype=init_dtype,
            device=device,
            requires_grad=True,
        )
        y = op(x)
        y.backward(torch.zeros_like(y))
        assert y.dtype == final_dtype
        assert x.grad.dtype == init_dtype
        assert op.weight.grad.dtype == final_dtype

    @pytest.mark.parametrize("model_dtype", _dtypes)
    @pytest.mark.parametrize("autocast_dtype", _dtypes)
    @pytest.mark.parametrize("quantization", _quantization_list)
    def test_pyt_autocast(
        self,
        *,
        size: int = 32,
        model_dtype: torch.dtype,
        autocast_dtype: torch.dtype,
        device: torch.device = "cuda",
        quantization: Optional[str],
        quantized_weights: bool = False,
    ) -> None:
        """Test with PyTorch autocast

        Under ``torch.autocast`` the output takes the autocast dtype while
        the input and weight grads keep the model dtype. When a
        quantization recipe is given, the check is repeated with
        ``te.autocast`` and ``torch.autocast`` nested in both orders.
        """
        device = torch.device(device)

        # Skip invalid configurations
        in_shape = (size, size)
        quantized_compute = quantization is not None
        maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=model_dtype)
        maybe_skip_quantization(quantization, dtype=autocast_dtype)

        # Construct operation
        recipe = make_recipe(quantization)
        with te.quantized_model_init(enabled=quantized_weights, recipe=recipe):
            op = te_ops.Linear(size, size, bias=False, device=device, dtype=model_dtype)

        # Check forward and backward pass
        x = torch.zeros(
            in_shape,
            dtype=model_dtype,
            device=device,
            requires_grad=True,
        )
        with te.autocast(enabled=quantized_compute, recipe=recipe):
            with torch.autocast(device_type=device.type, dtype=autocast_dtype):
                y = op(x)
        y.backward(torch.zeros_like(y))
        assert y.dtype == autocast_dtype
        assert x.grad.dtype == model_dtype
        assert op.weight.grad.dtype == model_dtype

        # Check forward and backward pass (swapped context order)
        if quantized_compute:
            x.grad = None
            op.weight.grad = None
            with torch.autocast(device_type=device.type, dtype=autocast_dtype):
                with te.autocast(enabled=quantized_compute, recipe=recipe):
                    y = op(x)
            y.backward(torch.zeros_like(y))
            assert y.dtype == autocast_dtype
            assert x.grad.dtype == model_dtype
            assert op.weight.grad.dtype == model_dtype
class TestBasicOps:
"""Tests for individual operations"""
@classmethod
def setup_class(cls) -> None:
    """Reset RNG states before the tests in this class run"""
    reset_rng_states()
@pytest.mark.parametrize("dtype", _dtypes)
@pytest.mark.parametrize("device", ("cuda", "cpu"))
@pytest.mark.parametrize("quantization", _quantization_list)
def test_identity(
    self,
    *,
    in_shape: Iterable[int] = (32, 32),
    dtype: torch.dtype,
    device: torch.device,
    quantization: Optional[str],
) -> None:
    """Identity op passes values through unchanged in both directions

    Inputs and grads are kept as quantized tensors when a quantization
    scheme is given. The comparison is exact.
    """
    quantized = quantization is not None
    maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype)

    # Input and output-grad tensors share every option except requires_grad
    tensor_opts = dict(
        quantization=quantization,
        test_dtype=dtype,
        test_device=device,
        test_is_quantized=quantized,
    )
    x_ref, x_test = make_reference_and_test_tensors(in_shape, **tensor_opts)
    dy_ref, dy_test = make_reference_and_test_tensors(
        in_shape, requires_grad=False, **tensor_opts
    )

    # Reference: identity in forward and backward
    y_ref, dx_ref = x_ref, dy_ref

    # Fusible op under test
    y_test = te_ops.Identity()(x_test)
    y_test.backward(dy_test)

    exact = dict(rtol=0, atol=0)
    y_test = y_test.to(dtype=torch.float64, device="cpu")
    dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
    torch.testing.assert_close(y_test, y_ref, **exact)
    torch.testing.assert_close(dx_test, dx_ref, **exact)

    # Guard against a trivially passing comparison
    for actual, expected in ((y_test, y_ref), (dx_test, dx_ref)):
        with pytest.raises(AssertionError):
            torch.testing.assert_close(actual, -expected, **exact)
@pytest.mark.parametrize(
    "shapes",
    (
        ((1, 2, 3, 4), (2, 12)),
        ((5, 4, 3, 2), (-1, 6)),
        ((30,), (2, 3, -1)),
        ((6, 7), (3, -1, 7)),
    ),
)
@pytest.mark.parametrize("dtype", _dtypes)
@pytest.mark.parametrize("quantization", (None, "fp8_current_scaling"))
def test_reshape(
    self,
    *,
    shapes: tuple[Iterable[int], Iterable[int]],
    dtype: torch.dtype,
    device: torch.device = "cuda",
    memory_format: torch.memory_format = torch.contiguous_format,
    quantization: Optional[str],
) -> None:
    """Reshape op matches ``torch.reshape`` exactly in forward and backward"""
    in_shape, out_shape = shapes

    # channels_last is only defined for 4D tensors
    if memory_format == torch.channels_last and len(in_shape) != 4:
        pytest.skip("torch.channels_last only supports 4D tensors")
    maybe_skip_quantization(quantization, device=device, dtype=dtype)
    quantized = quantization is not None

    tensor_opts = dict(
        quantization=quantization,
        test_dtype=dtype,
        test_device=device,
        test_is_quantized=quantized,
    )
    x_ref, x_test = make_reference_and_test_tensors(in_shape, **tensor_opts)
    # Apply the requested memory layout, then turn the result into a leaf
    x_test = x_test.contiguous(memory_format=memory_format).detach().requires_grad_()
    dy_ref, dy_test = make_reference_and_test_tensors(
        x_ref.reshape(out_shape).size(),
        requires_grad=False,
        **tensor_opts,
    )

    # Reference implementation
    y_ref = x_ref.reshape(out_shape)
    y_ref.backward(dy_ref)

    # Fusible op under test
    y_test = te_ops.Reshape(out_shape)(x_test)
    y_test.backward(dy_test)

    # Compare exactly in float64 on CPU with a canonical layout
    to_ref = dict(dtype=torch.float64, device="cpu", memory_format=torch.contiguous_format)
    y_test = y_test.to(**to_ref)
    dx_test = x_test.grad.to(**to_ref)
    exact = dict(rtol=0, atol=0)
    torch.testing.assert_close(y_test, y_ref, **exact)
    torch.testing.assert_close(dx_test, x_ref.grad, **exact)
@pytest.mark.parametrize("size", (1, 7, 32))
@pytest.mark.parametrize("in_shape", ((-1,), (1, 3, -1), (4, 3, 8, -1)))
@pytest.mark.parametrize("dtype", _dtypes)
@pytest.mark.parametrize("device", _devices)
@pytest.mark.parametrize("quantization", _quantization_list)
def test_bias(
    self,
    *,
    size: int,
    in_shape: Iterable[int],
    dtype: torch.dtype,
    device: torch.device,
    quantization: Optional[str],
) -> None:
    """Bias op matches a broadcast add, including input and bias grads"""

    # The last input dim must equal the bias size
    in_shape = [*list(in_shape)[:-1], size]

    quantized = quantization is not None
    maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype)

    # Random data (bias itself is never quantized)
    x_ref, x_test = make_reference_and_test_tensors(
        in_shape,
        quantization=quantization,
        test_dtype=dtype,
        test_device=device,
        test_is_quantized=quantized,
    )
    b_ref, b_test = make_reference_and_test_tensors(
        size,
        test_dtype=dtype,
        test_device=device,
    )
    dy_ref, dy_test = make_reference_and_test_tensors(
        in_shape,
        quantization=quantization,
        test_dtype=dtype,
        test_device=device,
        test_is_quantized=quantized,
        requires_grad=False,
    )

    # Reference: broadcast the bias over every leading dim
    bias_view = [1] * (len(in_shape) - 1) + [size]
    y_ref = x_ref + b_ref.reshape(bias_view)
    y_ref.backward(dy_ref)

    # Fusible op under test
    op = te_ops.Bias(size, device=device, dtype=dtype)
    with torch.no_grad():
        op.bias.copy_(b_test)
        del b_test
    y_test = op(x_test)
    y_test.backward(dy_test)

    def to_ref(tensor: torch.Tensor) -> torch.Tensor:
        return tensor.to(dtype=torch.float64, device="cpu")

    tols = dtype_tols(dtype)
    pairs = (
        (to_ref(y_test), y_ref),
        (to_ref(x_test.grad), x_ref.grad),
        (to_ref(op.bias.grad), b_ref.grad),
    )
    for actual, expected in pairs:
        torch.testing.assert_close(actual, expected, **tols)
@pytest.mark.parametrize("quantization", _quantization_list)
@pytest.mark.parametrize("cast_forward", (False, True))
@pytest.mark.parametrize("cast_backward", (False, True))
def test_quantize(
    self,
    *,
    in_shape: Iterable[int] = (32, 32),
    dtype: torch.dtype = torch.bfloat16,
    device: torch.device = "cuda",
    quantization: Optional[str],
    cast_forward: bool,
    cast_backward: bool,
) -> None:
    """Quantize

    Checks that ``te_ops.Quantize`` leaves values unchanged (inputs and
    grads are pre-quantized, so they are exactly representable) and,
    when run under ``te.autocast`` with a quantization recipe, returns a
    ``QuantizedTensor`` output only if ``cast_forward`` is set and a
    ``QuantizedTensor`` input grad only if ``cast_backward`` is set.

    ``quantization`` may be ``None`` (it is parametrized from
    ``_quantization_list``); in that case only the values are checked.
    """

    # Skip invalid configurations
    with_quantization = quantization is not None
    maybe_skip_quantization(quantization, device=device, dtype=dtype)
    if quantization == "mxfp8":
        maybe_skip_quantization(quantization, dims=in_shape)

    # Random data
    x_ref, x_test = make_reference_and_test_tensors(
        in_shape,
        quantization=quantization,
        test_dtype=dtype,
        test_device=device,
        requires_grad=True,
    )
    dy_ref, dy_test = make_reference_and_test_tensors(
        in_shape,
        quantization=quantization,
        test_dtype=dtype,
        test_device=device,
        requires_grad=False,
    )

    # Plain PyTorch implementation
    y_ref = x_ref
    dx_ref = dy_ref

    # Implementation with fusible operation
    op = te_ops.Quantize(forward=cast_forward, backward=cast_backward)
    recipe = make_recipe(quantization)
    with te.autocast(enabled=with_quantization, recipe=recipe):
        y_test = op(x_test)
    y_test.backward(dy_test)

    # Check tensor types
    if with_quantization:
        assert isinstance(y_test, QuantizedTensor) == cast_forward
        assert isinstance(x_test.grad, QuantizedTensor) == cast_backward

    # Check values
    tols = dict(rtol=0, atol=0)
    y_test = y_test.to(dtype=torch.float64, device="cpu")
    dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
    torch.testing.assert_close(y_test, y_ref, **tols)
    torch.testing.assert_close(dx_test, dx_ref, **tols)
def _test_basic_linear(
    self,
    *,
    weight_shape: tuple[int, int] = (32, 32),
    in_shape: Iterable[int] = (32, -1),
    dtype: torch.dtype = torch.float32,
    device: torch.device = "cuda",
    quantization: Optional[str] = None,
    quantized_compute: bool = False,
    quantized_input: bool = False,
    quantized_weight: bool = False,
    quantized_output: bool = False,
    quantized_grad_output: bool = False,
    quantized_grad_input: bool = False,
    accumulate_into_main_grad: bool = False,
) -> None:
    """Helper function for tests with GEMM

    Builds a ``te_ops.BasicLinear`` op wrapped between two
    ``te_ops.Quantize`` ops (controlling quantization of the input /
    grad input and of the output / grad output). Its forward output,
    input grad, and weight grad are compared against
    ``torch.nn.functional.linear`` computed in float64.

    ``weight_shape`` is ``(out_features, in_features)``. The last entry
    of ``in_shape`` is replaced with ``in_features``.

    With ``accumulate_into_main_grad``, the weight grad is expected to
    be added into ``op.weight.main_grad`` (seeded with 0.5) rather than
    written to ``op.weight.grad``. Otherwise ``main_grad`` must remain
    untouched.

    The test is skipped for unsupported or meaningless combinations:
    quantization flags without a scheme, a scheme with no flags, and
    quantized outputs / grad inputs for recipes that do not support them.
    """

    # Make input and weight shapes consistent
    out_features, in_features = weight_shape
    in_shape = list(in_shape)[:-1] + [in_features]
    out_shape = in_shape[:-1] + [out_features]

    # Skip invalid configurations
    maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype)
    maybe_skip_quantization(quantization, dims=out_shape)
    quantization_needed = any(
        (
            quantized_compute,
            quantized_input,
            quantized_weight,
            quantized_output,
            quantized_grad_output,
            quantized_grad_input,
        )
    )
    if quantization is None and quantization_needed:
        pytest.skip("Quantization scheme is not specified")
    if quantization is not None and not quantization_needed:
        pytest.skip("Quantization scheme is not used")
    if quantization in ("fp8", "fp8_delayed_scaling", "fp8_current_scaling"):
        if quantized_output and not quantized_compute:
            pytest.skip("FP8 output is only supported with FP8 GEMMs")
        if quantized_grad_input and not quantized_compute:
            pytest.skip("FP8 grad input is only supported with FP8 GEMMs")
    # NOTE(review): "fp8" never appears in _quantization_list, so for the
    # recipes parametrized in this file any quantization skips quantized
    # output / grad input here, which also subsumes the FP8-specific
    # checks above -- confirm this is intended
    if quantization not in (None, "fp8"):
        if quantized_output or quantized_grad_input:
            pytest.skip("Recipe does not support quantized GEMM output")

    # Random data (construction order affects RNG consumption)
    x_ref, x_test = make_reference_and_test_tensors(
        in_shape,
        quantization=quantization,
        test_dtype=dtype,
        test_device=device,
        test_is_quantized=quantized_input,
    )
    w_ref, w_test = make_reference_and_test_tensors(
        (out_features, in_features),
        quantization=quantization,
        test_dtype=dtype,
        test_device=device,
    )
    dy_ref, dy_test = make_reference_and_test_tensors(
        out_shape,
        quantization=quantization,
        test_dtype=dtype,
        test_device=device,
        test_is_quantized=quantized_grad_output,
        requires_grad=False,
    )

    # Plain PyTorch implementation
    y_ref = torch.nn.functional.linear(x_ref, w_ref)
    y_ref.backward(dy_ref)

    # Implementation with fusible operation
    recipe = make_recipe(quantization)
    with te.quantized_model_init(enabled=quantized_weight, recipe=recipe):
        op = te_ops.BasicLinear(
            in_features,
            out_features,
            device=device,
            dtype=dtype,
            accumulate_into_main_grad=accumulate_into_main_grad,
        )
    # Quantize ops around the linear control quantization of the input,
    # grad input, output, and grad output
    forward = te_ops.Sequential(
        te_ops.Quantize(forward=quantized_input, backward=quantized_grad_input),
        op,
        te_ops.Quantize(forward=quantized_output, backward=quantized_grad_output),
    )
    with torch.no_grad():
        op.weight.copy_(w_test)
        del w_test
        # Seed main_grad with a known nonzero value so accumulation (or
        # its absence) can be detected afterwards
        op.weight.main_grad = torch.full_like(op.weight, 0.5, dtype=torch.float32)
    with te.autocast(enabled=quantized_compute, recipe=recipe):
        y_test = forward(x_test)
    y_test.backward(dy_test)

    # Expected numerical error
    tols = dtype_tols(dtype)
    if dtype == torch.float32:
        tols = dtype_tols(torch.float16)  # TF32 GEMM
    if quantized_compute or quantized_output or quantized_grad_input:
        tols = quantization_tols(quantization)

    # Check results
    y_test = y_test.to(dtype=torch.float64, device="cpu")
    dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
    torch.testing.assert_close(y_test, y_ref, **tols)
    torch.testing.assert_close(dx_test, x_ref.grad, **tols)
    if accumulate_into_main_grad:
        # If a .grad tensor exists at all, it must hold only zeros
        if op.weight.grad is not None:
            torch.testing.assert_close(
                op.weight.grad,
                torch.zeros_like(op.weight.grad),
                rtol=0,
                atol=0,
            )
        # Remove the 0.5 seed to recover the accumulated weight grad
        dw_test = op.weight.main_grad.to(dtype=torch.float64, device="cpu") - 0.5
    else:
        dw_test = op.weight.grad.to(dtype=torch.float64, device="cpu")
        # main_grad must be left exactly as seeded
        torch.testing.assert_close(
            op.weight.main_grad,
            torch.full_like(op.weight.main_grad, 0.5),
            rtol=0,
            atol=0,
        )
    torch.testing.assert_close(dw_test, w_ref.grad, **tols)
@pytest.mark.parametrize("weight_shape", ((64, 32), (3, 5)))
@pytest.mark.parametrize("in_shape", ((-1,), (5, 1, -1), (4, 2, 4, -1)))
@pytest.mark.parametrize("dtype", _dtypes)
@pytest.mark.parametrize("quantization", _quantization_list)
@pytest.mark.parametrize("accumulate_into_main_grad", (False, True))
def test_basic_linear(
    self,
    *,
    weight_shape: tuple[int, int],
    in_shape: Iterable[int],
    dtype: torch.dtype,
    quantization: Optional[str],
    accumulate_into_main_grad: bool,
) -> None:
    """GEMM

    Delegates to ``_test_basic_linear``, enabling quantized compute
    whenever a quantization recipe is given.
    """
    options = dict(
        weight_shape=weight_shape,
        in_shape=in_shape,
        dtype=dtype,
        quantization=quantization,
        quantized_compute=quantization is not None,
        accumulate_into_main_grad=accumulate_into_main_grad,
    )
    self._test_basic_linear(**options)
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
@pytest.mark.parametrize("quantization", _quantization_list)
@pytest.mark.parametrize("quantized_compute", (False, True))
@pytest.mark.parametrize("quantized_input", (False, True))
@pytest.mark.parametrize("quantized_weight", (False, True))
@pytest.mark.parametrize("quantized_output", (False, True))
@pytest.mark.parametrize("quantized_grad_output", (False, True))
@pytest.mark.parametrize("quantized_grad_input", (False, True))
def test_basic_linear_quantized(
self,
*,
quantization: str,
quantized_compute: bool,