-
Notifications
You must be signed in to change notification settings - Fork 1.1k
Expand file tree
/
Copy path: gpu_context_common.h
More file actions
236 lines (203 loc) · 8.53 KB
/
gpu_context_common.h
File metadata and controls
236 lines (203 loc) · 8.53 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
#ifndef HALIDE_RUNTIME_GPU_CONTEXT_COMMON_H_
#define HALIDE_RUNTIME_GPU_CONTEXT_COMMON_H_
#include "HalideRuntime.h"
#include "printer.h"
#include "scoped_mutex_lock.h"
namespace Halide {
namespace Internal {
// A cache mapping (GPU context, kernel id) -> compiled module state, shared
// by the GPU runtimes. Implemented as an open-addressed (linear probing)
// hash table protected by a single mutex. Entries are reference counted via
// use_count so in-flight kernels are never released out from under a caller.
template<typename ContextT, typename ModuleStateT>
class GPUCompilationCache {
    // One slot of the hash table. A slot is empty when kernel_id ==
    // kInvalidId and a tombstone when kernel_id == kDeletedId; real
    // entries always have kernel_id >= 2 (see unique_id below).
    struct CachedCompilation {
        ContextT context{};
        ModuleStateT module_state{};
        uintptr_t kernel_id{0};
        uintptr_t use_count{0};
    };

    // Guards all mutable state below.
    halide_mutex mutex;

    static constexpr float kLoadFactor{.5f};
    static constexpr int kInitialTableBits{7};
    int log2_compilations_size{0};  // number of bits in index into compilations table.
    CachedCompilation *compilations{nullptr};
    int count{0};  // number of live (non-tombstone) entries currently in the table.

    static constexpr uintptr_t kInvalidId{0};
    static constexpr uintptr_t kDeletedId{1};

    uintptr_t unique_id{2};  // zero is an invalid id

    // Hash (context, id) into a table index that is `bits` wide.
    static ALWAYS_INLINE uintptr_t kernel_hash(ContextT context, uintptr_t id, int bits) {
        uintptr_t addr = (uintptr_t)context + id;
        // Fibonacci hashing. The golden ratio is 1.9E3779B97F4A7C15F39...
        // in hexadecimal.
        if constexpr (sizeof(uintptr_t) >= 8) {
            return (addr * (uintptr_t)0x9E3779B97F4A7C15) >> (64 - bits);
        } else {
            return (addr * (uintptr_t)0x9E3779B9) >> (32 - bits);
        }
    }

    // Place `entry` into the table, growing the table first if adding one
    // more entry would exceed the load factor. Returns false on allocation
    // failure. Caller must hold `mutex`.
    HALIDE_MUST_USE_RESULT bool insert(const CachedCompilation &entry) {
        if (log2_compilations_size == 0) {
            if (!resize_table(kInitialTableBits)) {
                return false;
            }
        }
        if ((count + 1) > (1 << log2_compilations_size) * kLoadFactor) {
            if (!resize_table(log2_compilations_size + 1)) {
                return false;
            }
        }
        uintptr_t index = kernel_hash(entry.context, entry.kernel_id, log2_compilations_size);
        // Linear probe; both empty slots and tombstones can take the entry.
        for (int i = 0; i < (1 << log2_compilations_size); i++) {
            uintptr_t effective_index = (index + i) & ((1 << log2_compilations_size) - 1);

            if (compilations[effective_index].kernel_id <= kDeletedId) {
                compilations[effective_index] = entry;
                // Count the entry only after successful placement so `count`
                // never includes an entry that is not actually in the table
                // (the original incremented before probing, leaving `count`
                // wrong on the table-full failure path).
                count += 1;
                return true;
            }
        }
        // This is a logic error that should never occur. It means the table is
        // full, but it should have been resized.
        halide_debug_assert(nullptr, false);
        return false;
    }

    // Look up (context, id). On success, points `module_state` at the cached
    // module and adjusts the entry's use_count by `increment` (may be
    // negative; zero means a pure lookup). Caller must hold `mutex`.
    HALIDE_MUST_USE_RESULT bool find_internal(ContextT context, uintptr_t id,
                                              ModuleStateT *&module_state, int increment) {
        if (log2_compilations_size == 0) {
            // Table has never been allocated; nothing can be cached.
            return false;
        }
        uintptr_t index = kernel_hash(context, id, log2_compilations_size);
        for (int i = 0; i < (1 << log2_compilations_size); i++) {
            uintptr_t effective_index = (index + i) & ((1 << log2_compilations_size) - 1);
            if (compilations[effective_index].kernel_id == kInvalidId) {
                // A never-used slot terminates the probe chain; a tombstone
                // (kDeletedId) does not, so deleted slots are probed past.
                return false;
            }
            if (compilations[effective_index].context == context &&
                compilations[effective_index].kernel_id == id) {
                module_state = &compilations[effective_index].module_state;
                if (increment != 0) {
                    compilations[effective_index].use_count += increment;
                }
                return true;
            }
        }
        return false;
    }

    // Allocate (or grow) the table to 2^size_bits slots and rehash all live
    // entries into it. Returns false on allocation failure, leaving the
    // existing table intact. Caller must hold `mutex`.
    HALIDE_MUST_USE_RESULT bool resize_table(int size_bits) {
        if (size_bits != log2_compilations_size) {
            int new_size = (1 << size_bits);
            int old_size = (1 << log2_compilations_size);
            CachedCompilation *new_table = (CachedCompilation *)malloc(new_size * sizeof(CachedCompilation));
            if (new_table == nullptr) {
                // signal error.
                return false;
            }
            memset(new_table, 0, new_size * sizeof(CachedCompilation));
            CachedCompilation *old_table = compilations;
            compilations = new_table;
            log2_compilations_size = size_bits;
            if (count > 0) {  // Mainly to catch empty initial table case
                // BUGFIX: insert() increments `count` for every entry it
                // places, so it must be reset before rehashing. Without
                // this, `count` roughly doubles on each resize, causing
                // spurious table growth and — because release_all() only
                // frees the table when count reaches 0 — preventing the
                // table from ever being freed even after every entry has
                // been released.
                count = 0;
                for (int32_t i = 0; i < old_size; i++) {
                    if (old_table[i].kernel_id != kInvalidId &&
                        old_table[i].kernel_id != kDeletedId) {
                        bool result = insert(old_table[i]);
                        halide_debug_assert(nullptr, result);  // Resizing the table while resizing the table is a logic error.
                        (void)result;
                    }
                }
            }
            free(old_table);
        }
        return true;
    }

    // Release every entry that is not in use (use_count == 0) and matches
    // `context` (or any context when `all` is true), invoking `f` on each
    // released module state. Caller must hold `mutex`.
    template<typename FreeModuleT>
    void release_context_already_locked(void *user_context, bool all, ContextT context, FreeModuleT &f) {
        if (count == 0) {
            return;
        }

        for (int i = 0; i < (1 << log2_compilations_size); i++) {
            if (compilations[i].kernel_id > kDeletedId &&
                (all || (compilations[i].context == context)) &&
                compilations[i].use_count == 0) {
                debug(user_context) << "Releasing cached compilation: " << compilations[i].module_state
                                    << " id " << compilations[i].kernel_id
                                    << " context " << compilations[i].context << "\n";
                f(compilations[i].module_state);
                compilations[i].module_state = nullptr;
                // Leave a tombstone so probe chains through this slot
                // remain unbroken for find_internal().
                compilations[i].kernel_id = kDeletedId;
                count--;
            }
        }
    }

public:
    // Fetch the cached module for (context, state_ptr) without changing its
    // use count. Returns false if nothing is cached for that key.
    HALIDE_MUST_USE_RESULT bool lookup(ContextT context, void *state_ptr, ModuleStateT &module_state) {
        ScopedMutexLock lock_guard(&mutex);
        uintptr_t id = (uintptr_t)state_ptr;
        ModuleStateT *mod_ptr;
        if (find_internal(context, id, mod_ptr, 0)) {
            module_state = *mod_ptr;
            return true;
        }
        return false;
    }

    // Release all unused cached modules for `context` (or for every context
    // when `all` is true), calling `f` on each released module.
    template<typename FreeModuleT>
    void release_context(void *user_context, bool all, ContextT context, FreeModuleT &f) {
        ScopedMutexLock lock_guard(&mutex);
        release_context_already_locked(user_context, all, context, f);
    }

    // Release all unused cached modules belonging to `context`.
    template<typename FreeModuleT>
    void delete_context(void *user_context, ContextT context, FreeModuleT &f) {
        ScopedMutexLock lock_guard(&mutex);
        release_context_already_locked(user_context, false, context, f);
    }

    // Release every unused cached module across all contexts, then free the
    // table itself if no entries remain.
    template<typename FreeModuleT>
    void release_all(void *user_context, FreeModuleT &f) {
        ScopedMutexLock lock_guard(&mutex);
        release_context_already_locked(user_context, true, nullptr, f);
        // Some items may have been in use, so can't free.
        if (count == 0) {
            free(compilations);
            compilations = nullptr;
            log2_compilations_size = 0;
        }
    }

    // Return (via `result`) the cached module for (context, *state_ptr_ptr),
    // compiling and caching it with `f(args...)` on a miss. On success the
    // entry's use_count is incremented; pair each successful call with
    // release_hold(). Returns false on compilation or allocation failure.
    template<typename CompileModuleT, typename... Args>
    HALIDE_MUST_USE_RESULT bool kernel_state_setup(void *user_context, void **state_ptr_ptr,
                                                   ContextT context, ModuleStateT &result,
                                                   CompileModuleT f,
                                                   Args... args) {
        ScopedMutexLock lock_guard(&mutex);

        uintptr_t *id_ptr = (uintptr_t *)state_ptr_ptr;
        if (*id_ptr == 0) {
            // Check exhaustion before handing out an id, so a failing call
            // does not consume one (the original assigned first, then
            // failed with the id already stored in *id_ptr).
            if (unique_id == (uintptr_t)-1) {
                // Sorry, out of ids
                return false;
            }
            *id_ptr = unique_id++;
        }

        ModuleStateT *mod;
        if (find_internal(context, *id_ptr, mod, 1)) {
            result = *mod;
            return true;
        }

        // TODO(zvookin): figure out the calling signature here...
        ModuleStateT compiled_module = f(args...);
        debug(user_context) << "Caching compiled kernel: " << compiled_module
                            << " id " << *id_ptr << " context " << context << "\n";
        if (compiled_module == nullptr) {
            return false;
        }

        // New entries start with use_count 1 on behalf of this caller.
        if (!insert({context, compiled_module, *id_ptr, 1})) {
            return false;
        }
        result = compiled_module;

        return true;
    }

    // Drop one reference acquired by kernel_state_setup(). The entry must
    // currently be in the cache.
    void release_hold(void *user_context, ContextT context, void *state_ptr) {
        ScopedMutexLock lock_guard(&mutex);
        ModuleStateT *mod;
        uintptr_t id = (uintptr_t)state_ptr;
        bool result = find_internal(context, id, mod, -1);
        halide_debug_assert(user_context, result);  // Value must be in cache to be released
        (void)result;
    }
};
} // namespace Internal
} // namespace Halide
#endif // HALIDE_RUNTIME_GPU_CONTEXT_COMMON_H_