@@ -50,8 +50,6 @@ struct DevicePools {
5050};
5151
5252std::vector<std::unique_ptr<DevicePools>> g_pools;
53- thread_local std::vector<std::unique_ptr<phi::CUDAStream>> tls_current_streams;
54- thread_local bool tls_streams_initialized = false ;
5553
5654void initGlobalState () {
5755 std::call_once (g_init_once, []() {
@@ -101,13 +99,6 @@ inline void check_gpu(c10::DeviceIndex device_index) {
10199 " )" );
102100}
103101
104- inline void initTLSCurrentStreams () {
105- if (!tls_streams_initialized) {
106- tls_current_streams.resize (g_num_gpus);
107- tls_streams_initialized = true ;
108- }
109- }
110-
111102inline phi::GPUContext* getMutableGPUContext (c10::DeviceIndex device_index) {
112103 return static_cast <phi::GPUContext*>(
113104 paddle::experimental::DeviceContextPool::Instance ().GetMutable (
@@ -209,7 +200,6 @@ CUDAStream getCurrentCUDAStream(c10::DeviceIndex device_index) {
209200 static_cast <c10::DeviceIndex>(phi::backends::gpu::GetCurrentDeviceId ());
210201 }
211202 check_gpu (device_index);
212- initTLSCurrentStreams ();
213203 auto raw = getPaddleCurrentStream (device_index);
214204 if (raw == nullptr ) {
215205 return getDefaultCUDAStream (device_index);
@@ -225,16 +215,7 @@ void setCurrentCUDAStream(CUDAStream stream) {
225215 initGlobalState ();
226216 c10::DeviceIndex idx = stream.unwrap ().device_index ();
227217 check_gpu (idx);
228- initTLSCurrentStreams ();
229- auto & current_stream = tls_current_streams[idx];
230- if (!current_stream) {
231- current_stream =
232- std::make_unique<phi::CUDAStream>(phi::GPUPlace (idx), stream.stream ());
233- } else {
234- current_stream->set_raw_stream (stream.stream ());
235- }
236- getMutableGPUContext (idx)->SetCUDAStream (current_stream.get (),
237- /* clear=*/ false );
218+ getMutableGPUContext (idx)->SetStream (stream.stream ());
238219#else
239220 (void )stream;
240221#endif
0 commit comments