Skip to content

Commit 0d1bde3

Browse files
committed
remove tls stream
1 parent 1322ea6 commit 0d1bde3

1 file changed

Lines changed: 1 addition & 20 deletions

File tree

paddle/phi/api/include/compat/c10/cuda/CUDAStream.cpp

Lines changed: 1 addition & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,6 @@ struct DevicePools {
5050
};
5151

5252
std::vector<std::unique_ptr<DevicePools>> g_pools;
53-
thread_local std::vector<std::unique_ptr<phi::CUDAStream>> tls_current_streams;
54-
thread_local bool tls_streams_initialized = false;
5553

5654
void initGlobalState() {
5755
std::call_once(g_init_once, []() {
@@ -101,13 +99,6 @@ inline void check_gpu(c10::DeviceIndex device_index) {
10199
")");
102100
}
103101

104-
inline void initTLSCurrentStreams() {
105-
if (!tls_streams_initialized) {
106-
tls_current_streams.resize(g_num_gpus);
107-
tls_streams_initialized = true;
108-
}
109-
}
110-
111102
inline phi::GPUContext* getMutableGPUContext(c10::DeviceIndex device_index) {
112103
return static_cast<phi::GPUContext*>(
113104
paddle::experimental::DeviceContextPool::Instance().GetMutable(
@@ -209,7 +200,6 @@ CUDAStream getCurrentCUDAStream(c10::DeviceIndex device_index) {
209200
static_cast<c10::DeviceIndex>(phi::backends::gpu::GetCurrentDeviceId());
210201
}
211202
check_gpu(device_index);
212-
initTLSCurrentStreams();
213203
auto raw = getPaddleCurrentStream(device_index);
214204
if (raw == nullptr) {
215205
return getDefaultCUDAStream(device_index);
@@ -225,16 +215,7 @@ void setCurrentCUDAStream(CUDAStream stream) {
225215
initGlobalState();
226216
c10::DeviceIndex idx = stream.unwrap().device_index();
227217
check_gpu(idx);
228-
initTLSCurrentStreams();
229-
auto& current_stream = tls_current_streams[idx];
230-
if (!current_stream) {
231-
current_stream =
232-
std::make_unique<phi::CUDAStream>(phi::GPUPlace(idx), stream.stream());
233-
} else {
234-
current_stream->set_raw_stream(stream.stream());
235-
}
236-
getMutableGPUContext(idx)->SetCUDAStream(current_stream.get(),
237-
/*clear=*/false);
218+
getMutableGPUContext(idx)->SetStream(stream.stream());
238219
#else
239220
(void)stream;
240221
#endif

0 commit comments

Comments
 (0)