-
Notifications
You must be signed in to change notification settings - Fork 699
Expand file tree
/
Copy pathcuda_runtime.cpp
More file actions
267 lines (240 loc) · 7.75 KB
/
cuda_runtime.cpp
File metadata and controls
267 lines (240 loc) · 7.75 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
/*************************************************************************
* Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include "../util/cuda_runtime.h"
#include <cublasLt.h>
#include <filesystem>
#include <fstream>
#include <mutex>
#include "../common.h"
#include "../util/cuda_driver.h"
#include "../util/system.h"
#include "common/util/cuda_runtime.h"
namespace transformer_engine {
namespace cuda {
namespace {
// String with build-time CUDA include path.
// The generated header is included inside an anonymous namespace so that
// whatever it defines has internal linkage in this translation unit. It is
// expected to provide `string_path_cuda_include`, which include_directory()
// uses as one of its fallback search locations.
// NOTE(review): the header is generated by the build; its exact contents are
// not visible here.
#include "string_path_cuda_include.h"
} // namespace
int num_devices() {
  // Number of CUDA devices visible to this process. The query runs once, on
  // first use, and the result is reused afterwards (function-local static
  // initialization is thread-safe).
  static const int cached_count = [] {
    int n = 0;
    NVTE_CHECK_CUDA(cudaGetDeviceCount(&n));
    return n;
  }();
  return cached_count;
}
int current_device() {
// Return 0 if CUDA context is not initialized
CUcontext context;
NVTE_CALL_CHECK_CUDA_DRIVER(cuCtxGetCurrent, &context);
if (context == nullptr) {
return 0;
}
// Query device from CUDA runtime
int device_id;
NVTE_CHECK_CUDA(cudaGetDevice(&device_id));
return device_id;
}
int sm_arch(int device_id) {
static std::vector<int> cache(num_devices(), -1);
static std::vector<std::once_flag> flags(num_devices());
if (device_id < 0) {
device_id = current_device();
}
NVTE_CHECK(0 <= device_id && device_id < num_devices(), "invalid CUDA device ID");
auto init = [&]() {
cudaDeviceProp prop;
NVTE_CHECK_CUDA(cudaGetDeviceProperties(&prop, device_id));
cache[device_id] = 10 * prop.major + prop.minor;
};
std::call_once(flags[device_id], init);
return cache[device_id];
}
int sm_count(int device_id) {
static std::vector<int> cache(num_devices(), -1);
static std::vector<std::once_flag> flags(num_devices());
if (device_id < 0) {
device_id = current_device();
}
NVTE_CHECK(0 <= device_id && device_id < num_devices(), "invalid CUDA device ID");
auto init = [&]() {
cudaDeviceProp prop;
NVTE_CHECK_CUDA(cudaGetDeviceProperties(&prop, device_id));
cache[device_id] = prop.multiProcessorCount;
};
std::call_once(flags[device_id], init);
return cache[device_id];
}
void stream_priority_range(int *low_priority, int *high_priority, int device_id) {
static std::vector<std::pair<int, int>> cache(num_devices());
static std::vector<std::once_flag> flags(num_devices());
if (device_id < 0) {
device_id = current_device();
}
NVTE_CHECK(0 <= device_id && device_id < num_devices(), "invalid CUDA device ID");
auto init = [&]() {
int ori_dev = current_device();
if (device_id != ori_dev) NVTE_CHECK_CUDA(cudaSetDevice(device_id));
int min_pri, max_pri;
NVTE_CHECK_CUDA(cudaDeviceGetStreamPriorityRange(&min_pri, &max_pri));
if (device_id != ori_dev) NVTE_CHECK_CUDA(cudaSetDevice(ori_dev));
cache[device_id] = std::make_pair(min_pri, max_pri);
};
std::call_once(flags[device_id], init);
*low_priority = cache[device_id].first;
*high_priority = cache[device_id].second;
}
bool supports_multicast(int device_id) {
  // Whether a device supports multicast objects
  // (CU_DEVICE_ATTRIBUTE_MULTICAST_SUPPORTED).
  // device_id: device to query; a negative value means the current device.
  // Returns false when built against, or running with, a CUDA runtime older
  // than 12.1, or when the driver reports a version older than 12.1.
  // Each device's value is fetched once, guarded by its own once_flag, and
  // cached for subsequent calls.
#if CUDART_VERSION >= 12010
  // NOTE: This needs to be guarded at compile-time and run-time because the
  // CU_DEVICE_ATTRIBUTE_MULTICAST_SUPPORTED enum is not defined in earlier CUDA versions.
  if (cudart_version() < 12010) {
    return false;
  }
  // Deliberately std::vector<char> rather than std::vector<bool>: vector<bool>
  // packs elements into shared words, so call_once initializers for different
  // devices running concurrently would race on the same storage. Distinct
  // char elements are distinct memory locations.
  static std::vector<char> cache(num_devices(), 0);
  static std::vector<std::once_flag> flags(num_devices());
  if (device_id < 0) {
    device_id = current_device();
  }
  NVTE_CHECK(0 <= device_id && device_id < num_devices(), "invalid CUDA device ID");
  auto init = [&]() {
    CUdevice cudev;
    NVTE_CALL_CHECK_CUDA_DRIVER(cuDeviceGet, &cudev, device_id);
    // Multicast support requires both CUDA12.1 UMD + KMD
    int result = 0;
    // Check if KMD >= 12.1
    int driver_version = 0;
    NVTE_CHECK_CUDA(cudaDriverGetVersion(&driver_version));
    if (driver_version >= 12010) {
      NVTE_CALL_CHECK_CUDA_DRIVER(cuDeviceGetAttribute, &result,
                                  CU_DEVICE_ATTRIBUTE_MULTICAST_SUPPORTED, cudev);
    }
    cache[device_id] = (result != 0) ? 1 : 0;
  };
  std::call_once(flags[device_id], init);
  return cache[device_id] != 0;
#else
  return false;
#endif
}
const std::string &include_directory(bool required) {
  // Directory containing cuda_runtime.h (CUDA Toolkit headers).
  //
  // Candidates are tried in order: $NVTE_CUDA_INCLUDE_DIR, $CUDA_HOME,
  // $CUDA_DIR, the build-time CUDA include path, then /usr/local/cuda. Each
  // candidate is checked both as-is and with an "include" subdirectory.
  //
  // required: if true and no directory is found, an error is raised via
  //           NVTE_ERROR listing every location that was tried.
  // Returns an empty string when nothing is found and `required` is false.
  //
  // The result is cached. A previous unsuccessful search is repeated only when
  // a caller passes required=true. The cached path only ever changes from
  // empty to non-empty, so a non-empty returned reference stays stable.
  static std::string path;
  static bool need_to_check_env = true;
  // Guard the cached state against concurrent callers, consistent with the
  // call_once-protected caches used by the other queries in this file.
  static std::mutex cache_mutex;
  std::lock_guard<std::mutex> lock(cache_mutex);

  // Update cached path if needed
  if (path.empty() && required) {
    need_to_check_env = true;
  }
  if (need_to_check_env) {
    // Search for CUDA headers in common paths.
    // Each entry is (environment variable name, path); an entry with an empty
    // path is filled from its environment variable before being checked.
    using Path = std::filesystem::path;
    std::vector<std::pair<std::string, Path>> search_paths = {{"NVTE_CUDA_INCLUDE_DIR", ""},
                                                              {"CUDA_HOME", ""},
                                                              {"CUDA_DIR", ""},
                                                              {"", string_path_cuda_include},
                                                              {"", "/usr/local/cuda"}};
    for (auto &[env, p] : search_paths) {
      if (p.empty()) {
        p = getenv<Path>(env.c_str());
      }
      if (!p.empty()) {
        if (file_exists(p / "cuda_runtime.h")) {
          path = p.string();
          break;
        }
        if (file_exists(p / "include" / "cuda_runtime.h")) {
          path = (p / "include").string();
          break;
        }
      }
    }

    // Throw exception if path is required but not found
    if (path.empty() && required) {
      std::string message;
      message.reserve(2048);
      message += "Could not find cuda_runtime.h in";
      bool is_first = true;
      for (const auto &[env, p] : search_paths) {
        message += is_first ? " " : ", ";
        is_first = false;
        if (!env.empty()) {
          message += env;
          message += "=";
        }
        if (p.empty()) {
          message += "<unset>";
        } else {
          // Explicit conversion: path's implicit conversion yields a wide
          // string on some platforms.
          message += p.string();
        }
      }
      message +=
          (". "
           "Specify path to CUDA Toolkit headers "
           "with NVTE_CUDA_INCLUDE_DIR "
           "or disable NVRTC support with NVTE_DISABLE_NVRTC=1.");
      NVTE_ERROR(message);
    }
    need_to_check_env = false;
  }

  // Return cached path
  return path;
}
int include_directory_version(bool required) {
  // CUDART_VERSION parsed from <include_directory>/cuda_runtime_api.h.
  // required: if true, failure to locate the headers or parse the version is
  //           reported via NVTE_ERROR instead of returning -1.
  // Returns the first positive CUDART_VERSION value found, or -1.
  const std::string &dir = cuda::include_directory(false);
  if (dir.empty()) {
    if (required) {
      NVTE_ERROR(
          "Could not detect version of CUDA Toolkit headers "
          "(CUDA Toolkit headers not found).");
    }
    return -1;
  }

  const std::string marker = "#define CUDART_VERSION ";
  // Integer following the marker on a single line; -1 if the marker is absent
  // or the text after it is not a parseable integer.
  auto version_on_line = [&marker](const std::string &text) -> int {
    const std::size_t at = text.find(marker);
    if (at == std::string::npos) {
      return -1;
    }
    try {
      return std::stoi(text.substr(at + marker.size()));
    } catch (...) {
      return -1;
    }
  };

  const auto header_path = std::filesystem::path(dir) / "cuda_runtime_api.h";
  std::ifstream header(header_path);
  if (header.is_open()) {
    std::string line;
    while (std::getline(header, line)) {
      const int parsed = version_on_line(line);
      if (parsed > 0) {
        return parsed;
      }
    }
  }

  if (required) {
    NVTE_ERROR(
        "Could not detect version of CUDA Toolkit headers "
        "(Could not parse CUDART_VERSION from ",
        header_path.string(), ").");
  }
  return -1;
}
int cudart_version() {
  // Version of the CUDA runtime as reported by cudaRuntimeGetVersion.
  // Queried once on first use and cached for the lifetime of the process.
  static const int runtime_version = [] {
    int v = 0;
    NVTE_CHECK_CUDA(cudaRuntimeGetVersion(&v));
    return v;
  }();
  return runtime_version;
}
size_t cublas_version() {
  // cuBLASLt library version, fetched once and reused afterwards so repeated
  // calls do not pay cuBLAS logging overhead.
  static const size_t cached = cublasLtGetVersion();
  return cached;
}
} // namespace cuda
} // namespace transformer_engine