diff --git a/aten/src/ATen/native/cuda/CUDAScalar.cu b/aten/src/ATen/native/cuda/CUDAScalar.cu
index 169a2ab92615f..8b75bf3bf4821 100644
--- a/aten/src/ATen/native/cuda/CUDAScalar.cu
+++ b/aten/src/ATen/native/cuda/CUDAScalar.cu
@@ -18,6 +18,14 @@ Scalar _local_scalar_dense_cuda(const Tensor& self) {
   TORCH_CHECK(self.numel() > 0, "_local_scalar_dense: Empty tensor not supported");
   AT_DISPATCH_V2(
       self.scalar_type(), "_local_scalar_dense_cuda", AT_WRAP([&] {
+#ifdef USE_ROCM
+        // If this is a large BAR device, we can just read directly from VRAM
+        if (at::cuda::getCurrentDeviceProperties()->isLargeBar) {
+          cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+          at::cuda::stream_synchronize(stream);
+          r = Scalar(*self.const_data_ptr<scalar_t>());
+        } else {
+#endif
         // Create pinned memory for the scalar value to avoid implicit
         // locking/sync in cuda library due to pageable memory
         auto value = at::detail::empty_cpu(
@@ -31,6 +39,9 @@ Scalar _local_scalar_dense_cuda(const Tensor& self) {
         cudaStream_t stream = at::cuda::getCurrentCUDAStream();
         at::cuda::memcpy_and_sync(value.mutable_data_ptr(), self.const_data_ptr(), sizeof(scalar_t), cudaMemcpyDeviceToHost, stream);
         r = Scalar(*value.const_data_ptr<scalar_t>());
+#ifdef USE_ROCM
+        }
+#endif
       }), AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), kComplexHalf, kHalf, kBool, kBFloat16, AT_EXPAND(AT_FLOAT8_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES));
   return r;
 }