-
Notifications
You must be signed in to change notification settings - Fork 702
Expand file tree
/
Copy pathpadding.cu
More file actions
387 lines (344 loc) · 15.4 KB
/
padding.cu
File metadata and controls
387 lines (344 loc) · 15.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
/*************************************************************************
* Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <cuda_runtime.h>
#include <transformer_engine/padding.h>
#include <cfloat>
#include <iostream>
#include <vector>
#include "../common.h"
#include "../utils.cuh"
namespace transformer_engine {
namespace {
// Parameters to tune
// Number of warps that cooperate on one tile; together with
// THREADS_PER_WARP it fixes the thread-block size below.
constexpr int n_warps_per_tile = 4;
// Threads per CUDA block; used both as the launch block size and as
// the __launch_bounds__ limit of the kernels in this file.
constexpr int threads_per_block = THREADS_PER_WARP * n_warps_per_tile;
// Target number of bytes each thread handles per row access. The
// launchers derive nvec = desired_load_store_size / sizeof(Type).
constexpr int desired_load_store_size = 8;
// Capacity of the per-tensor arrays in MultiPaddingArgs, i.e. the
// maximum number of tensors handled by a single kernel launch.
constexpr int kMaxTensorsPerKernel = 64; // Args must be <4 KB
// Kernel argument bundle describing a batch of up to
// kMaxTensorsPerKernel 2D tensors. It is passed by value to
// multi_padding_kernel / multi_unpadding_kernel.
struct MultiPaddingArgs {
  // (input) Data buffers for input tensors
  void* input_list[kMaxTensorsPerKernel];
  // (output) Data buffers for output tensors
  void* output_list[kMaxTensorsPerKernel];
  // Number of rows that are copied from input to output
  int num_rows_list[kMaxTensorsPerKernel];
  // Number of rows of the padded output. Only read by the padding
  // kernel; the unpadding kernel ignores it.
  int padded_num_rows_list[kMaxTensorsPerKernel];
  // Matrix widths (number of elements per row)
  int row_length_list[kMaxTensorsPerKernel];
  // Prefix sum (with leading zero) of CUDA blocks needed for each
  // tensor
  int block_range[kMaxTensorsPerKernel + 1];
  // Number of tensors being processed by kernel
  int num_tensors;
};
// The struct travels as a by-value kernel parameter, so keep it within
// the 4 KB kernel-argument budget that kMaxTensorsPerKernel was chosen for.
static_assert(sizeof(MultiPaddingArgs) <= 4096,
              "MultiPaddingArgs must fit in 4 KB of kernel arguments");
// Copies a batch of row-major 2D tensors into their outputs and fills
// the extra output rows [num_rows, padded_num_rows) with zeros.
//
// Launch layout:
//  - 1D grid; block `bid` handles one tile of the tensor whose
//    [block_range[t], block_range[t + 1]) interval contains `bid`.
//  - 1D block of threads_per_block threads, interpreted as a
//    THREADS_PER_WARP x n_warps_per_tile grid.
//  - No shared memory is used.
//
// Template parameters:
//  - nvec: number of consecutive elements each thread handles per row.
//  - Type: element type of both input and output.
//
// The host must compute tile counts with the same tile size
// (THREADS_PER_WARP * nvec in each dimension) as used here.
template <int nvec, typename Type>
__global__ void __launch_bounds__(threads_per_block) multi_padding_kernel(MultiPaddingArgs args) {
  using Vec = Vec<Type, nvec>;

  static_assert(nvec > 0, "Each thread must handle at least one element per row");
  // n_iterations below is an integer division; a remainder would leave
  // some rows of every tile unprocessed.
  static_assert(THREADS_PER_WARP % n_warps_per_tile == 0,
                "THREADS_PER_WARP must be divisible by n_warps_per_tile");

  // Thread indices
  // Note: Block is interpreted as a warp_size x num_warps grid
  constexpr int bdimx = THREADS_PER_WARP;
  constexpr int bdimy = n_warps_per_tile;
  const int tid = threadIdx.x;
  const int tidx = tid % bdimx;
  const int tidy = tid / bdimx;
  const int bid = blockIdx.x;

  // Input tensors are divided into tiles
  // Note: Each tile is a warp_size x warp_size grid of nvec x nvec subtiles
  constexpr int tile_dim_m = THREADS_PER_WARP * nvec;
  constexpr int tile_dim_n = THREADS_PER_WARP * nvec;

  // Number of nvec x nvec subtiles for each thread to
  // load/store
  constexpr int n_iterations = THREADS_PER_WARP / n_warps_per_tile;

  // Find tensor corresponding to block
  int tensor_id = 0;
  while (args.block_range[tensor_id + 1] <= bid) {
    ++tensor_id;
  }
  const Type* input = reinterpret_cast<const Type*>(args.input_list[tensor_id]);
  Type* output = reinterpret_cast<Type*>(args.output_list[tensor_id]);
  const int num_rows = args.num_rows_list[tensor_id];
  const int padded_num_rows = args.padded_num_rows_list[tensor_id];
  const int row_length = args.row_length_list[tensor_id];

  // Find position of tile within tensor
  const int num_tiles_n = (row_length + tile_dim_n - 1) / tile_dim_n;
  const int tile_id = bid - args.block_range[tensor_id];
  const int tile_id_m = tile_id / num_tiles_n;
  const int tile_id_n = tile_id % num_tiles_n;
  const int tile_row = tile_id_m * tile_dim_m;
  const int tile_col = tile_id_n * tile_dim_n;

  // Each thread owns a fixed group of nvec columns of the tile and
  // walks over n_iterations subtiles of nvec rows each.
  const int col = tile_col + tidx * nvec;
  const Type local_zero = static_cast<Type>(0.f);
#pragma unroll 4
  for (int iter = 0; iter < n_iterations; ++iter) {
    const int i1 = tidy + iter * bdimy;
#pragma unroll
    for (int i2 = 0; i2 < nvec; ++i2) {
      const int row = tile_row + i1 * nvec + i2;
      // 64-bit offset so that large tensors do not overflow int arithmetic
      const size_t offset = static_cast<size_t>(row) * row_length + col;
      if (row < num_rows) {
        // Data row: stage the elements in registers, then write them out.
        Vec staged;
#pragma unroll
        for (int j2 = 0; j2 < nvec; ++j2) {
          if (col + j2 < row_length) {
            staged.data.elt[j2] = input[offset + j2];
          }
        }
#pragma unroll
        for (int j2 = 0; j2 < nvec; ++j2) {
          if (col + j2 < row_length) {
            output[offset + j2] = staged.data.elt[j2];
          }
        }
      } else if (row < padded_num_rows) {
        // Padding row: exists only in the output and is zero-filled.
#pragma unroll
        for (int j2 = 0; j2 < nvec; ++j2) {
          if (col + j2 < row_length) {
            output[offset + j2] = local_zero;
          }
        }
      }
    }
  }
}
// Copies the first num_rows rows of each row-major 2D input tensor in
// a batch into the corresponding output tensor (i.e. drops padding rows).
//
// Launch layout:
//  - 1D grid; block `bid` handles one tile of the tensor whose
//    [block_range[t], block_range[t + 1]) interval contains `bid`.
//  - 1D block of threads_per_block threads, interpreted as a
//    THREADS_PER_WARP x n_warps_per_tile grid.
//  - No shared memory is used.
//
// Template parameters:
//  - nvec: number of consecutive elements each thread handles per row.
//  - Type: element type of both input and output.
//
// The host must compute tile counts with the same tile size
// (THREADS_PER_WARP * nvec in each dimension) as used here.
// args.padded_num_rows_list is not read by this kernel.
template <int nvec, typename Type>
__global__ void __launch_bounds__(threads_per_block) multi_unpadding_kernel(MultiPaddingArgs args) {
  using Vec = Vec<Type, nvec>;

  static_assert(nvec > 0, "Each thread must handle at least one element per row");
  // n_iterations below is an integer division; a remainder would leave
  // some rows of every tile unprocessed.
  static_assert(THREADS_PER_WARP % n_warps_per_tile == 0,
                "THREADS_PER_WARP must be divisible by n_warps_per_tile");

  // Thread indices
  // Note: Block is interpreted as a warp_size x num_warps grid
  constexpr int bdimx = THREADS_PER_WARP;
  constexpr int bdimy = n_warps_per_tile;
  const int tid = threadIdx.x;
  const int tidx = tid % bdimx;
  const int tidy = tid / bdimx;
  const int bid = blockIdx.x;

  // Input tensors are divided into tiles
  // Note: Each tile is a warp_size x warp_size grid of nvec x nvec subtiles
  constexpr int tile_dim_m = THREADS_PER_WARP * nvec;
  constexpr int tile_dim_n = THREADS_PER_WARP * nvec;

  // Number of nvec x nvec subtiles for each thread to
  // load/store
  constexpr int n_iterations = THREADS_PER_WARP / n_warps_per_tile;

  // Find tensor corresponding to block
  int tensor_id = 0;
  while (args.block_range[tensor_id + 1] <= bid) {
    ++tensor_id;
  }
  const Type* input = reinterpret_cast<const Type*>(args.input_list[tensor_id]);
  Type* output = reinterpret_cast<Type*>(args.output_list[tensor_id]);
  const int num_rows = args.num_rows_list[tensor_id];
  const int row_length = args.row_length_list[tensor_id];

  // Find position of tile within tensor
  const int num_tiles_n = (row_length + tile_dim_n - 1) / tile_dim_n;
  const int tile_id = bid - args.block_range[tensor_id];
  const int tile_id_m = tile_id / num_tiles_n;
  const int tile_id_n = tile_id % num_tiles_n;
  const int tile_row = tile_id_m * tile_dim_m;
  const int tile_col = tile_id_n * tile_dim_n;

  // Each thread owns a fixed group of nvec columns of the tile and
  // walks over n_iterations subtiles of nvec rows each.
  const int col = tile_col + tidx * nvec;
#pragma unroll 4
  for (int iter = 0; iter < n_iterations; ++iter) {
    const int i1 = tidy + iter * bdimy;
#pragma unroll
    for (int i2 = 0; i2 < nvec; ++i2) {
      const int row = tile_row + i1 * nvec + i2;
      // Rows at or beyond num_rows are padding and are not copied.
      if (row < num_rows) {
        // 64-bit offset so that large tensors do not overflow int arithmetic
        const size_t offset = static_cast<size_t>(row) * row_length + col;
        // Stage the elements in registers, then write them out.
        Vec staged;
#pragma unroll
        for (int j2 = 0; j2 < nvec; ++j2) {
          if (col + j2 < row_length) {
            staged.data.elt[j2] = input[offset + j2];
          }
        }
#pragma unroll
        for (int j2 = 0; j2 < nvec; ++j2) {
          if (col + j2 < row_length) {
            output[offset + j2] = staged.data.elt[j2];
          }
        }
      }
    }
  }
}
} // namespace
// Pads a batch of 2D tensors along the row dimension.
//
// For every i, the rows of input_list[i] are copied into output_list[i]
// and the remaining output rows up to padded_num_rows_list[i] are set
// to zero. All tensors must share one dtype. Work is enqueued on
// `stream` in batches of at most kMaxTensorsPerKernel tensors; the
// function does not synchronize.
//
// Parameters:
//   input_list           - 2D source tensors.
//   output_list          - 2D destination tensors, one per input, with
//                          padded_num_rows_list[i] rows and the same
//                          row length as the input.
//   padded_num_rows_list - number of rows of each output (>= input rows).
//   stream               - CUDA stream used for the kernel launches.
//
// Fails via NVTE_CHECK on inconsistent arguments and via
// NVTE_CHECK_CUDA on launch errors.
void multi_padding(const std::vector<Tensor*> input_list, std::vector<Tensor*> output_list,
                   const std::vector<int> padded_num_rows_list, cudaStream_t stream) {
  // Check that number of tensors is valid
  NVTE_CHECK(output_list.size() == input_list.size(),
             "Number of input and output tensors must match");
  if (input_list.empty()) {
    return;
  }
  // Guard the per-tensor indexing of padded_num_rows_list below
  NVTE_CHECK(padded_num_rows_list.size() == input_list.size(),
             "Number of padded row counts must match number of input tensors");

  // Check that tensor properties are valid
  DType type = input_list[0]->data.dtype;
  for (size_t tensor_id = 0; tensor_id < input_list.size(); ++tensor_id) {
    const auto& input = *input_list[tensor_id];
    const auto& output = *output_list[tensor_id];
    CheckInputTensor(input, "multi_padding_input_" + std::to_string(tensor_id));
    CheckInputTensor(output, "multi_padding_output_" + std::to_string(tensor_id));

    NVTE_CHECK(input.data.dtype == type, "Input tensor types do not match.");
    NVTE_CHECK(output.data.dtype == type, "Output tensor types do not match.");

    NVTE_CHECK(input.data.shape.size() == 2, "Input tensor must have 2 dimensions.");
    NVTE_CHECK(output.data.shape.size() == 2, "Output tensor must have 2 dimensions.");
    NVTE_CHECK(output.data.shape[0] == padded_num_rows_list[tensor_id],
               "output tensor shape does not match padded input shape.");
    // The kernel writes every input row into the output, so a shorter
    // or differently shaped output would be written out of bounds.
    NVTE_CHECK(output.data.shape[0] >= input.data.shape[0],
               "Padded number of rows must not be smaller than number of input rows.");
    NVTE_CHECK(output.data.shape[1] == input.data.shape[1],
               "Input and output tensors must have the same row length.");
  }

  // Input matrices are divided into tiles
  // Note: Each tile is a warp_size x warp_size grid of nvec x nvec subtiles
  const int tile_dim_m = THREADS_PER_WARP * desired_load_store_size * 8 / typeToNumBits(type);
  const int tile_dim_n = THREADS_PER_WARP * desired_load_store_size * 8 / typeToNumBits(type);

  // Kernel argument struct, filled batch by batch
  MultiPaddingArgs kernel_args{};
  kernel_args.num_tensors = 0;
  kernel_args.block_range[0] = 0;

  // Launches the kernel for the tensors gathered so far and resets the
  // batch. A batch without any tiles is skipped because a grid with
  // zero blocks is not a valid launch configuration.
  const auto launch_pending = [&]() {
    const int n_blocks = kernel_args.block_range[kernel_args.num_tensors];
    if (n_blocks > 0) {
      TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(
          type, Type, constexpr int nvec = desired_load_store_size / sizeof(Type);
          multi_padding_kernel<nvec, Type>
          <<<n_blocks, threads_per_block, 0, stream>>>(kernel_args););  // NOLINT(*)
      NVTE_CHECK_CUDA(cudaGetLastError());
    }
    kernel_args.num_tensors = 0;
  };

  // Add tensors to kernel argument struct
  for (size_t tensor_id = 0; tensor_id < input_list.size(); ++tensor_id) {
    // Launch kernel if argument struct is full
    if (kernel_args.num_tensors == kMaxTensorsPerKernel) {
      launch_pending();
    }

    // Calculate number of thread blocks needed for tensor
    // Note: tiles cover the padded height so that padding rows get written
    const int num_rows = static_cast<int>(input_list[tensor_id]->data.shape[0]);
    const int padded_num_rows = padded_num_rows_list[tensor_id];
    const int row_length = static_cast<int>(input_list[tensor_id]->data.shape[1]);
    const int num_tiles_m = (padded_num_rows + tile_dim_m - 1) / tile_dim_m;
    const int num_tiles_n = (row_length + tile_dim_n - 1) / tile_dim_n;
    const int num_tiles = num_tiles_m * num_tiles_n;

    // Add tensor to kernel argument struct
    const int pos = kernel_args.num_tensors;
    kernel_args.input_list[pos] = const_cast<void*>(input_list[tensor_id]->data.dptr);
    kernel_args.output_list[pos] = output_list[tensor_id]->data.dptr;
    kernel_args.num_rows_list[pos] = num_rows;
    kernel_args.padded_num_rows_list[pos] = padded_num_rows;
    kernel_args.row_length_list[pos] = row_length;
    kernel_args.block_range[pos + 1] = kernel_args.block_range[pos] + num_tiles;
    kernel_args.num_tensors++;
  }

  // Launch kernel for the remaining tensors
  if (kernel_args.num_tensors > 0) {
    launch_pending();
  }
}
// Removes row padding from a batch of 2D tensors.
//
// For every i, the first unpadded_num_rows_list[i] rows of
// input_list[i] are copied into output_list[i]. All tensors must share
// one dtype. Work is enqueued on `stream` in batches of at most
// kMaxTensorsPerKernel tensors; the function does not synchronize.
//
// Parameters:
//   input_list             - 2D (padded) source tensors.
//   output_list            - 2D destination tensors, one per input,
//                            with unpadded_num_rows_list[i] rows and
//                            the same row length as the input.
//   unpadded_num_rows_list - number of rows to keep (<= input rows).
//   stream                 - CUDA stream used for the kernel launches.
//
// Fails via NVTE_CHECK on inconsistent arguments and via
// NVTE_CHECK_CUDA on launch errors.
void multi_unpadding(const std::vector<Tensor*> input_list, std::vector<Tensor*> output_list,
                     const std::vector<int> unpadded_num_rows_list, cudaStream_t stream) {
  // Check that number of tensors is valid
  NVTE_CHECK(output_list.size() == input_list.size(),
             "Number of input and output tensors must match");
  if (input_list.empty()) {
    return;
  }
  // Guard the per-tensor indexing of unpadded_num_rows_list below
  NVTE_CHECK(unpadded_num_rows_list.size() == input_list.size(),
             "Number of unpadded row counts must match number of input tensors");

  // Check that tensor properties are valid
  DType type = input_list[0]->data.dtype;
  for (size_t tensor_id = 0; tensor_id < input_list.size(); ++tensor_id) {
    const auto& input = *input_list[tensor_id];
    const auto& output = *output_list[tensor_id];
    CheckInputTensor(input, "multi_unpadding_input_" + std::to_string(tensor_id));
    CheckInputTensor(output, "multi_unpadding_output_" + std::to_string(tensor_id));

    NVTE_CHECK(input.data.dtype == type, "Input tensor types do not match.");
    NVTE_CHECK(output.data.dtype == type, "Output tensor types do not match.");

    NVTE_CHECK(input.data.shape.size() == 2, "Input tensor must have 2 dimensions.");
    NVTE_CHECK(output.data.shape.size() == 2, "Output tensor must have 2 dimensions.");
    NVTE_CHECK(output.data.shape[0] == unpadded_num_rows_list[tensor_id],
               "output tensor shape does not match unpadded input shape.");
    // The kernel reads that many rows from the input, so a shorter
    // input would be read out of bounds.
    NVTE_CHECK(input.data.shape[0] >= output.data.shape[0],
               "Unpadded number of rows must not exceed number of input rows.");
    NVTE_CHECK(output.data.shape[1] == input.data.shape[1],
               "Input and output tensors must have the same row length.");
  }

  // Input matrices are divided into tiles
  // Note: Each tile is a warp_size x warp_size grid of nvec x nvec subtiles
  // Note: Same formula as multi_padding so both paths agree on tile size
  const int tile_dim_m = THREADS_PER_WARP * desired_load_store_size * 8 / typeToNumBits(type);
  const int tile_dim_n = THREADS_PER_WARP * desired_load_store_size * 8 / typeToNumBits(type);

  // Kernel argument struct, filled batch by batch. Value-initialized
  // because padded_num_rows_list is never set on this path.
  MultiPaddingArgs kernel_args{};
  kernel_args.num_tensors = 0;
  kernel_args.block_range[0] = 0;

  // Launches the kernel for the tensors gathered so far and resets the
  // batch. A batch without any tiles is skipped because a grid with
  // zero blocks is not a valid launch configuration.
  const auto launch_pending = [&]() {
    const int n_blocks = kernel_args.block_range[kernel_args.num_tensors];
    if (n_blocks > 0) {
      TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(
          type, Type, constexpr int nvec = desired_load_store_size / sizeof(Type);
          multi_unpadding_kernel<nvec, Type>
          <<<n_blocks, threads_per_block, 0, stream>>>(kernel_args););  // NOLINT(*)
      NVTE_CHECK_CUDA(cudaGetLastError());
    }
    kernel_args.num_tensors = 0;
  };

  // Add tensors to kernel argument struct
  for (size_t tensor_id = 0; tensor_id < input_list.size(); ++tensor_id) {
    // Launch kernel if argument struct is full
    if (kernel_args.num_tensors == kMaxTensorsPerKernel) {
      launch_pending();
    }

    // Calculate number of thread blocks needed for tensor
    // Note: tiles only need to cover the rows that are kept
    const int num_rows = unpadded_num_rows_list[tensor_id];
    const int row_length = static_cast<int>(input_list[tensor_id]->data.shape[1]);
    const int num_tiles_m = (num_rows + tile_dim_m - 1) / tile_dim_m;
    const int num_tiles_n = (row_length + tile_dim_n - 1) / tile_dim_n;
    const int num_tiles = num_tiles_m * num_tiles_n;

    // Add tensor to kernel argument struct
    const int pos = kernel_args.num_tensors;
    kernel_args.input_list[pos] = const_cast<void*>(input_list[tensor_id]->data.dptr);
    kernel_args.output_list[pos] = output_list[tensor_id]->data.dptr;
    kernel_args.num_rows_list[pos] = num_rows;
    kernel_args.row_length_list[pos] = row_length;
    kernel_args.block_range[pos + 1] = kernel_args.block_range[pos] + num_tiles;
    kernel_args.num_tensors++;
  }

  // Launch kernel for the remaining tensors
  if (kernel_args.num_tensors > 0) {
    launch_pending();
  }
}
} // namespace transformer_engine
// C API entry point for batched row padding.
//
// Converts the `num_tensors` handles in `input_list` / `output_list`
// and the matching entries of `padded_num_rows_list` into the
// containers taken by transformer_engine::multi_padding and forwards
// the call on `stream`. Each array must hold at least `num_tensors`
// entries.
void nvte_multi_padding(size_t num_tensors, const NVTETensor* input_list, NVTETensor* output_list,
                        const int* padded_num_rows_list, cudaStream_t stream) {
  NVTE_API_CALL(nvte_multi_padding);
  using namespace transformer_engine;

  std::vector<Tensor*> input_list_, output_list_;
  std::vector<int> padded_num_rows_list_;
  // Final sizes are known up front; avoid repeated reallocation
  input_list_.reserve(num_tensors);
  output_list_.reserve(num_tensors);
  padded_num_rows_list_.reserve(num_tensors);

  for (size_t i = 0; i < num_tensors; ++i) {
    input_list_.push_back(convertNVTETensorCheck(input_list[i]));
    output_list_.push_back(convertNVTETensorCheck(output_list[i]));
    padded_num_rows_list_.push_back(padded_num_rows_list[i]);
  }

  multi_padding(input_list_, output_list_, padded_num_rows_list_, stream);
}
// C API entry point for batched removal of row padding.
//
// Converts the `num_tensors` handles in `input_list` / `output_list`
// and the matching entries of `unpadded_num_rows_list` into the
// containers taken by transformer_engine::multi_unpadding and forwards
// the call on `stream`. Each array must hold at least `num_tensors`
// entries.
void nvte_multi_unpadding(size_t num_tensors, const NVTETensor* input_list, NVTETensor* output_list,
                          const int* unpadded_num_rows_list, cudaStream_t stream) {
  NVTE_API_CALL(nvte_multi_unpadding);
  using namespace transformer_engine;

  std::vector<Tensor*> input_list_, output_list_;
  std::vector<int> unpadded_num_rows_list_;
  // Final sizes are known up front; avoid repeated reallocation
  input_list_.reserve(num_tensors);
  output_list_.reserve(num_tensors);
  unpadded_num_rows_list_.reserve(num_tensors);

  for (size_t i = 0; i < num_tensors; ++i) {
    input_list_.push_back(convertNVTETensorCheck(input_list[i]));
    output_list_.push_back(convertNVTETensorCheck(output_list[i]));
    unpadded_num_rows_list_.push_back(unpadded_num_rows_list[i]);
  }

  multi_unpadding(input_list_, output_list_, unpadded_num_rows_list_, stream);
}