-
Notifications
You must be signed in to change notification settings - Fork 702
Expand file tree
/
Copy pathgated.cuh
More file actions
198 lines (176 loc) · 9.35 KB
/
gated.cuh
File metadata and controls
198 lines (176 loc) · 9.35 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
/*************************************************************************
* Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
/*! \file gated.cuh
* \brief Gated dispatcher.
*/
#ifndef TRANSFORMER_ENGINE_DISPATCH_GATED_CUH_
#define TRANSFORMER_ENGINE_DISPATCH_GATED_CUH_
#include <transformer_engine/transformer_engine.h>
#include "../../common.h"
#include "../../transpose/transpose.h"
#include "../../utils.cuh"
#include "../fp8/gated_fp8.cuh"
#include "../mxfp8/gated_mxfp8.cuh"
namespace transformer_engine {
namespace dispatch {
/*! \brief Dispatch the forward gated-activation quantization kernel.
 *
 *  Validates tensor shapes, then routes to the FP8 (delayed tensor scaling)
 *  or MXFP8 (1D block scaling) implementation based on the output tensor's
 *  scaling mode.
 *
 *  \tparam ParamOP   Activation parameter struct type.
 *  \tparam ActOP     Forward activation function applied to the gate input.
 *
 *  \param nvte_input  Input tensor; flattened shape [*, 2 * cols] — the last
 *                     dimension must be even (activation and gate halves).
 *  \param nvte_output Output tensor; flattened shape [*, cols]. Rowwise
 *                     and/or columnwise data must be allocated.
 *  \param p           Parameters forwarded to the activation functor.
 *  \param stream      CUDA stream on which kernels are launched.
 */
template <typename ParamOP, float (*ActOP)(float, const ParamOP &)>
void quantize_gated_fwd_helper(const NVTETensor nvte_input, NVTETensor nvte_output, ParamOP &p,
                               cudaStream_t stream) {
  const Tensor input = *convertNVTETensorCheck(nvte_input);
  Tensor *output = convertNVTETensorCheck(nvte_output);
  CheckInputTensor(input, "input");
  CheckOutputTensor(*output, "output", /*allow_empty=*/false);

  // Each output column corresponds to one (activation, gate) pair of input
  // columns. Note: unlike the bwd helper, the row count is not needed here.
  const size_t cols = input.flat_last_dim() / 2;

  NVTE_CHECK(input.flat_last_dim() % 2 == 0,
             "Wrong input shape. Expected (after flattening) last dimension to be even, ", "got [",
             input.flat_first_dim(), ", ", input.flat_last_dim(), "].");
  NVTE_CHECK(output->flat_last_dim() == cols,
             "Wrong output shape. Expected (after flattening) [*, ", cols, "], got [",
             output->flat_first_dim(), ", ", output->flat_last_dim(), "].");
  NVTE_CHECK(output->has_data() || output->has_columnwise_data(),
             "Either rowwise or columnwise output data need to be allocated.");

  switch (output->scaling_mode) {
    case NVTE_DELAYED_TENSOR_SCALING: {
      //const bool use_tma_kernels = (cols % 32 == 0) && is_supported_by_CC_100();
      // sm120 shared memory capabilities are much smaller than sm100, so we disable TMA kernels on sm120
      // KL: It is possible that for fwd, the limits are not exceeded for sm120. To be investigated -
      // are there any forward only tests we'd like to keep enabled on sm120?
      const bool use_tma_kernels =
          (cols % 32 == 0) && is_supported_by_CC_100() && !is_supported_by_CC_120();
      if (use_tma_kernels) {
        // Forward pass has no incoming gradient; pass a placeholder tensor.
        Tensor dummy_grad_tensor;
        fp8::cast_gated_tma</*IS_BWD=*/false, ParamOP, ActOP, nullptr>(input, dummy_grad_tensor,
                                                                      output, p, stream);
      } else {
        fp8::cast_gated_fwd<ParamOP, ActOP>(input, output, p, stream);
      }

      if (is_fp8_dtype(output->dtype()) && output->has_columnwise_data()) {
        // FP8 kernel only populates row-wise data, so perform
        // transpose separately if needed
        Tensor transpose_in, transpose_out, dummy;
        transpose_in.scaling_mode = NVTE_DELAYED_TENSOR_SCALING;
        transpose_in.data.dptr = output->data.dptr;
        transpose_in.data.shape = {output->flat_first_dim(), output->flat_last_dim()};
        transpose_in.data.dtype = output->data.dtype;
        transpose_out.scaling_mode = NVTE_DELAYED_TENSOR_SCALING;
        transpose_out.data.dptr = output->columnwise_data.dptr;
        transpose_out.data.shape = {output->flat_last_dim(), output->flat_first_dim()};
        transpose_out.data.dtype = output->data.dtype;
        detail::transpose(transpose_in, /*noop=*/dummy, &transpose_out, stream);
      }
      break;
    }
    case NVTE_MXFP8_1D_SCALING: {
      // MXFP8 block scaling operates on 32-element groups along the last dim.
      NVTE_CHECK(cols % 32 == 0,
                 "Invalid input shape. Expected the last dimension to be "
                 "divisible by 32, but got ",
                 cols, ".");
      if (output->has_data()) {
        NVTE_CHECK(is_fp8_dtype(output->data.dtype),
                   "The type of the output tensor should be FP8.");
      }
      if (output->has_columnwise_data()) {
        NVTE_CHECK(is_fp8_dtype(output->columnwise_data.dtype),
                   "The type of the columnwise output tensor should be FP8.");
      }
      // requires SM100+
      NVTE_CHECK(is_supported_by_CC_100(),
                 "Gated FWD NVTE_MXFP8_1D_SCALING is only supported on SM 10.0+");
      Tensor dummy_grad_tensor;
      mxfp8::quantize_gated</*IS_BWD=*/false, ParamOP, ActOP, nullptr>(input, dummy_grad_tensor,
                                                                      output, p, stream);
      break;
    }
    default:
      NVTE_ERROR("Not supported scaling mode: " + to_string(output->scaling_mode) + ".");
  }
}
/*! \brief Dispatch the backward gated-activation quantization kernel.
 *
 *  Validates gradient/input/output tensor shapes and dtypes, then routes to
 *  the FP8 (delayed tensor scaling) or MXFP8 (1D block scaling) backward
 *  implementation based on the output tensor's scaling mode.
 *
 *  \tparam ParamOP   Activation parameter struct type.
 *  \tparam ActOP     Forward activation function.
 *  \tparam DActOP    Derivative of the activation function.
 *
 *  \param nvte_grad        Incoming gradient; flattened shape [rows, cols],
 *                          must be a non-FP8 (higher-precision) dtype
 *                          matching the gated input's dtype.
 *  \param nvte_gated_input Forward input; flattened shape [rows, 2 * cols]
 *                          with an even last dimension.
 *  \param nvte_output      Output gradient; same shape as the gated input.
 *                          Rowwise and/or columnwise data must be allocated.
 *  \param p                Parameters forwarded to the activation functors.
 *  \param stream           CUDA stream on which kernels are launched.
 */
template <typename ParamOP, float (*ActOP)(float, const ParamOP &),
          float (*DActOP)(float, const ParamOP &)>
void quantize_gated_bwd_helper(const NVTETensor nvte_grad, const NVTETensor nvte_gated_input,
                               NVTETensor nvte_output, ParamOP &p, cudaStream_t stream) {
  const Tensor &grad_tensor = *convertNVTETensorCheck(nvte_grad);
  const Tensor &act_input = *convertNVTETensorCheck(nvte_gated_input);
  Tensor *out = convertNVTETensorCheck(nvte_output);
  CheckInputTensor(grad_tensor, "grad");
  CheckInputTensor(act_input, "gated_input");
  CheckOutputTensor(*out, "output", /*allow_empty=*/false);

  NVTE_CHECK(act_input.flat_last_dim() % 2 == 0, "Number of columns must be even, but got ",
             act_input.flat_last_dim(), ".");

  // The gated input packs (activation, gate) pairs along its last dimension,
  // so the gradient and each output half have `cols` columns.
  const size_t rows = act_input.flat_first_dim();
  const size_t cols = act_input.flat_last_dim() / 2;

  NVTE_CHECK(!is_fp8_dtype(grad_tensor.dtype()), "Grad input must be in higher precision.");
  NVTE_CHECK(grad_tensor.dtype() == act_input.dtype(), "Types of both inputs must match.");
  NVTE_CHECK(grad_tensor.flat_first_dim() == rows,
             "Wrong Grad shape. Expected first dimension (after flattening) [", rows, ", *], got [",
             grad_tensor.flat_first_dim(), ", ", grad_tensor.flat_last_dim(), "].");
  NVTE_CHECK(grad_tensor.flat_last_dim() == cols,
             "Wrong Grad shape. Expected last dimension (after flattening) [", cols, ", *], got [",
             grad_tensor.flat_first_dim(), ", ", grad_tensor.flat_last_dim(), "].");
  NVTE_CHECK(out->has_data() || out->has_columnwise_data(),
             "Either rowwise or columnwise output data need to be allocated.");
  NVTE_CHECK(out->flat_first_dim() == rows, "Wrong output shape. Expected (after flattening) [",
             rows, ", *], got [", out->flat_first_dim(), ", ", out->flat_last_dim(), "].");
  NVTE_CHECK(out->flat_last_dim() == cols * 2,
             "Wrong output shape. Expected (after flattening) [*, ", cols * 2, "], got [",
             out->flat_first_dim(), ", ", out->flat_last_dim(), "].");
  NVTE_CHECK(act_input.shape() == out->shape(),
             "Gated input and output shapes must match. Input shape: ", act_input.shape(),
             ", output shape: ", out->shape(), ".");

  const auto mode = out->scaling_mode;
  if (mode == NVTE_DELAYED_TENSOR_SCALING) {
    //const bool use_tma = (cols % 32 == 0) && is_supported_by_CC_100();
    // sm120 shared memory capabilities are much smaller than sm100, so we disable TMA kernels on sm120
    const bool use_tma =
        (cols % 32 == 0) && is_supported_by_CC_100() && !is_supported_by_CC_120();
    if (use_tma) {
      fp8::cast_gated_tma</*IS_BWD=*/true, ParamOP, ActOP, DActOP>(act_input, grad_tensor, out, p,
                                                                   stream);
    } else {
      fp8::cast_gated_bwd<ParamOP, ActOP, DActOP>(act_input, grad_tensor, out, p, stream);
    }

    if (is_fp8_dtype(out->dtype()) && out->has_columnwise_data()) {
      // The FP8 kernel fills only the rowwise buffer; produce the columnwise
      // copy with an explicit transpose.
      Tensor t_src, t_dst, noop;
      t_src.scaling_mode = NVTE_DELAYED_TENSOR_SCALING;
      t_src.data.dptr = out->data.dptr;
      t_src.data.shape = {out->flat_first_dim(), out->flat_last_dim()};
      t_src.data.dtype = out->data.dtype;
      t_dst.scaling_mode = NVTE_DELAYED_TENSOR_SCALING;
      t_dst.data.dptr = out->columnwise_data.dptr;
      t_dst.data.shape = {out->flat_last_dim(), out->flat_first_dim()};
      t_dst.data.dtype = out->data.dtype;
      detail::transpose(t_src, /*noop=*/noop, &t_dst, stream);
    }
  } else if (mode == NVTE_MXFP8_1D_SCALING) {
    // MXFP8 block scaling operates on 32-element groups along the last dim.
    NVTE_CHECK(cols % 32 == 0,
               "Invalid input shape. Expected the last dimension to be "
               "divisible by 32, but got ",
               cols, ".");
    if (out->has_data()) {
      NVTE_CHECK(is_fp8_dtype(out->data.dtype), "The type of the output tensor should be FP8.");
    }
    if (out->has_columnwise_data()) {
      NVTE_CHECK(is_fp8_dtype(out->columnwise_data.dtype),
                 "The type of the columnwise output tensor should be FP8.");
    }
    // requires SM100+
    NVTE_CHECK(is_supported_by_CC_100(),
               "Gated BWD NVTE_MXFP8_1D_SCALING is only supported on SM 10.0+");
    mxfp8::quantize_gated</*IS_BWD=*/true, ParamOP, ActOP, DActOP>(act_input, grad_tensor, out, p,
                                                                   stream);
  } else {
    NVTE_ERROR("Not supported scaling mode: " + to_string(mode) + ".");
  }
}
} // namespace dispatch
} // namespace transformer_engine
#endif // TRANSFORMER_ENGINE_DISPATCH_GATED_CUH_