diff --git a/paddle/phi/api/include/compat/ATen/ops/resize.h b/paddle/phi/api/include/compat/ATen/ops/resize.h
index 44232e54ff935..5b6b22a623855 100644
--- a/paddle/phi/api/include/compat/ATen/ops/resize.h
+++ b/paddle/phi/api/include/compat/ATen/ops/resize.h
@@ -16,20 +16,90 @@
 #include
 #include
+#include <algorithm>
 #include
 #include
+#include <limits>
 #include "paddle/phi/api/include/api.h"
+#include "paddle/phi/common/memory_utils.h"
 
 namespace at {
 
-// resize_ - in-place resize using reshape
+namespace detail {
+
+inline int64_t ResizeCheckedNumel(at::IntArrayRef size) {
+  int64_t numel = 1;
+  for (const auto dim : size) {
+    TORCH_CHECK(dim >= 0,
+                "Trying to create tensor with negative dimension ",
+                dim,
+                ": ",
+                size);
+    if (dim == 0) {
+      numel = 0;
+      continue;
+    }
+    TORCH_CHECK(numel <= std::numeric_limits<int64_t>::max() / dim,
+                "resize_ size is too large, possible overflow for size ",
+                size);
+    numel *= dim;
+  }
+  return numel;
+}
+
+}  // namespace detail
+
+// resize_ - operate on the underlying DenseTensor directly so we preserve
+// storage semantics across shrink/grow round-trips and only reallocate when
+// the requested shape exceeds the current storage capacity.
 inline const at::Tensor& Tensor::resize_(
     at::IntArrayRef size,
     ::std::optional<MemoryFormat> memory_format) const {
-  auto result =
-      paddle::experimental::reshape(tensor_, size._PD_ToPaddleIntArray());
-  const_cast<Tensor*>(this)->tensor_ = result;
+  // Keep old compat behavior for memory_format in this split PR.
+  // TODO(youge325): add real ChannelsLast/ChannelsLast3d restride support
+  // later.
+  (void)memory_format;
+
+  std::vector<int64_t> dims(size.begin(), size.end());
+  int64_t new_numel = detail::ResizeCheckedNumel(size);
+  auto dense_tensor =
+      std::dynamic_pointer_cast<phi::DenseTensor>(tensor_.impl());
+  TORCH_CHECK(dense_tensor != nullptr,
+              "resize_ only supports DenseTensor, but got a non-dense tensor");
+  TORCH_CHECK(tensor_.defined(),
+              "resize_ is not allowed on an undefined tensor");
+
+  const size_t itemsize = phi::SizeOf(dense_tensor->dtype());
+  const size_t old_numel = static_cast<size_t>(tensor_.numel());
+  const size_t new_numel_size = static_cast<size_t>(new_numel);
+  const size_t required_bytes = new_numel_size * itemsize;
+  const size_t available_bytes =
+      dense_tensor->Holder() == nullptr
+          ? 0
+          : dense_tensor->Holder()->size() - dense_tensor->offset();
+
+  if (required_bytes <= available_bytes || new_numel == 0) {
+    dense_tensor->Resize(dims);
+    return *this;
+  }
+
+  const auto old_holder = dense_tensor->Holder();
+  TORCH_CHECK(old_holder != nullptr,
+              "resize_ cannot grow a tensor without allocated storage");
+  const size_t old_offset = dense_tensor->offset();
+  const size_t copy_bytes = std::min(old_numel, new_numel_size) * itemsize;
+  const phi::Place place = old_holder->place();
+  const void* old_data =
+      old_holder == nullptr
+          ?
+              nullptr
+          : reinterpret_cast<const char*>(old_holder->ptr()) + old_offset;
+
+  dense_tensor->Resize(dims);
+  void* new_data = dense_tensor->mutable_data(place, dense_tensor->dtype());
+  if (copy_bytes > 0 && old_data != nullptr && old_data != new_data) {
+    phi::memory_utils::Copy(place, new_data, place, old_data, copy_bytes);
+  }
   return *this;
 }
diff --git a/test/cpp/compat/ATen_resize_test.cc b/test/cpp/compat/ATen_resize_test.cc
index 0ad6b09f58eb2..82793bbf9a544 100644
--- a/test/cpp/compat/ATen_resize_test.cc
+++ b/test/cpp/compat/ATen_resize_test.cc
@@ -19,13 +19,17 @@
 #include
 #include
+#include <limits>
+#include <vector>
+
 #include "ATen/ATen.h"
 #include "gtest/gtest.h"
 #include "torch/all.h"
 
 // ======================== resize_ tests ========================
-// Note: Paddle's resize_ is implemented via reshape, which requires
-// total element count to remain unchanged.
+// Note: compat resize_ mutates the underlying DenseTensor directly so
+// shrink/grow round-trips preserve storage semantics without introducing new
+// memory_format hard errors in this split PR.
 TEST(TensorResizeTest, ResizeBasic) {
   // Create a 2x3 tensor
@@ -109,6 +113,92 @@ TEST(TensorResizeTest, ResizePreservesData) {
   ASSERT_FLOAT_EQ(data[5], 5.0f);
 }
 
+TEST(TensorResizeTest, ResizeShrinkDifferentNumel) {
+  at::Tensor t = at::arange(24, at::kFloat).reshape({2, 3, 4});
+
+  t.resize_({4, 5});
+
+  ASSERT_EQ(t.sizes()[0], 4);
+  ASSERT_EQ(t.sizes()[1], 5);
+
+  float* data = t.data_ptr<float>();
+  for (int i = 0; i < 20; ++i) {
+    ASSERT_FLOAT_EQ(data[i], static_cast<float>(i));
+  }
+}
+
+TEST(TensorResizeTest, ResizeGrowDifferentNumelPreservesPrefix) {
+  at::Tensor t = at::arange(6, at::kFloat).reshape({2, 3});
+
+  t.resize_({2, 5});
+
+  ASSERT_EQ(t.sizes()[0], 2);
+  ASSERT_EQ(t.sizes()[1], 5);
+
+  float* data = t.data_ptr<float>();
+  for (int i = 0; i < 6; ++i) {
+    ASSERT_FLOAT_EQ(data[i], static_cast<float>(i));
+  }
+}
+
+TEST(TensorResizeTest, ResizeShrinkGrowRoundTripPreservesTail) {
+  at::Tensor t = at::arange(24, at::kFloat).reshape({2, 3, 4});
+
+  t.resize_({4, 5});
+  t.resize_({2, 3, 4});
+
+  ASSERT_EQ(t.sizes()[0], 2);
+  ASSERT_EQ(t.sizes()[1], 3);
+  ASSERT_EQ(t.sizes()[2], 4);
+
+  float* data = t.data_ptr<float>();
+  for (int i = 0; i < 24; ++i) {
+    ASSERT_FLOAT_EQ(data[i], static_cast<float>(i));
+  }
+}
+
+TEST(TensorResizeTest, ResizeChannelsLastMemoryFormatDoesNotThrow) {
+  at::Tensor t = at::arange(24, at::kFloat).reshape({1, 2, 3, 4});
+
+  EXPECT_NO_THROW({
+    t.resize_(std::vector<int64_t>{1, 3, 2, 4}, at::MemoryFormat::ChannelsLast);
+  });
+
+  ASSERT_EQ(t.sizes()[0], 1);
+  ASSERT_EQ(t.sizes()[1], 3);
+  ASSERT_EQ(t.sizes()[2], 2);
+  ASSERT_EQ(t.sizes()[3], 4);
+}
+
+TEST(TensorResizeTest, ResizeChannelsLast3dMemoryFormatDoesNotThrow) {
+  at::Tensor t = at::arange(24, at::kFloat).reshape({1, 2, 2, 2, 3});
+
+  EXPECT_NO_THROW({
+    t.resize_(std::vector<int64_t>{1, 2, 2, 3, 2},
+              at::MemoryFormat::ChannelsLast3d);
+  });
+
+  ASSERT_EQ(t.sizes()[0], 1);
+  ASSERT_EQ(t.sizes()[1], 2);
+  ASSERT_EQ(t.sizes()[2], 2);
+  ASSERT_EQ(t.sizes()[3], 3);
+  ASSERT_EQ(t.sizes()[4], 2);
+}
+
+TEST(TensorResizeTest,
+     ResizeRejectsNegativeDimension) {
+  at::Tensor t = at::arange(6, at::kFloat);
+  auto bad_size = std::vector<int64_t>{2, -1};
+
+  EXPECT_THROW(t.resize_(bad_size), std::exception);
+}
+
+TEST(TensorResizeTest, ResizeRejectsNumelOverflow) {
+  at::Tensor t = at::arange(1, at::kFloat);
+  auto huge_size = std::vector<int64_t>{std::numeric_limits<int64_t>::max(), 2};
+
+  EXPECT_THROW(t.resize_(huge_size), std::exception);
+}
+
 TEST(TensorResizeTest, ResizeReturnReference) {
   // Create a tensor
   at::Tensor t = at::zeros({2, 3});