Skip to content

Commit 2893705

Browse files
authored
[Cpp API Compatibility] Align resize api (#78554)
1 parent 4f0218b commit 2893705

File tree

2 files changed

+166
-6
lines changed

2 files changed

+166
-6
lines changed

paddle/phi/api/include/compat/ATen/ops/resize.h

Lines changed: 74 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,20 +16,90 @@
1616

1717
#include <ATen/core/Tensor.h>
1818
#include <c10/core/TensorOptions.h>
19+
#include <limits>
1920
#include <optional>
2021
#include <string_view>
22+
#include <vector>
2123

2224
#include "paddle/phi/api/include/api.h"
25+
#include "paddle/phi/common/memory_utils.h"
2326

2427
namespace at {
2528

26-
// resize_ - in-place resize using reshape
29+
namespace detail {
30+
31+
inline int64_t ResizeCheckedNumel(at::IntArrayRef size) {
32+
int64_t numel = 1;
33+
for (const auto dim : size) {
34+
TORCH_CHECK(dim >= 0,
35+
"Trying to create tensor with negative dimension ",
36+
dim,
37+
": ",
38+
size);
39+
if (dim == 0) {
40+
numel = 0;
41+
continue;
42+
}
43+
TORCH_CHECK(numel <= std::numeric_limits<int64_t>::max() / dim,
44+
"resize_ size is too large, possible overflow for size ",
45+
size);
46+
numel *= dim;
47+
}
48+
return numel;
49+
}
50+
51+
} // namespace detail
52+
53+
// resize_ - operate on the underlying DenseTensor directly so we preserve
54+
// storage semantics across shrink/grow round-trips and only reallocate when
55+
// the requested shape exceeds the current storage capacity.
2756
inline const at::Tensor& Tensor::resize_(
2857
at::IntArrayRef size,
2958
::std::optional<at::MemoryFormat> memory_format) const {
30-
auto result =
31-
paddle::experimental::reshape(tensor_, size._PD_ToPaddleIntArray());
32-
const_cast<Tensor*>(this)->tensor_ = result;
59+
// Keep old compat behavior for memory_format in this split PR.
60+
// TODO(youge325): add real ChannelsLast/ChannelsLast3d restride support
61+
// later.
62+
(void)memory_format;
63+
64+
std::vector<int64_t> dims(size.begin(), size.end());
65+
int64_t new_numel = detail::ResizeCheckedNumel(size);
66+
auto dense_tensor =
67+
std::dynamic_pointer_cast<phi::DenseTensor>(tensor_.impl());
68+
TORCH_CHECK(dense_tensor != nullptr,
69+
"resize_ only supports DenseTensor, but got a non-dense tensor");
70+
TORCH_CHECK(tensor_.defined(),
71+
"resize_ is not allowed on an undefined tensor");
72+
73+
const size_t itemsize = phi::SizeOf(dense_tensor->dtype());
74+
const size_t old_numel = static_cast<size_t>(tensor_.numel());
75+
const size_t new_numel_size = static_cast<size_t>(new_numel);
76+
const size_t required_bytes = new_numel_size * itemsize;
77+
const size_t available_bytes =
78+
dense_tensor->Holder() == nullptr
79+
? 0
80+
: dense_tensor->Holder()->size() - dense_tensor->offset();
81+
82+
if (required_bytes <= available_bytes || new_numel == 0) {
83+
dense_tensor->Resize(dims);
84+
return *this;
85+
}
86+
87+
const auto old_holder = dense_tensor->Holder();
88+
TORCH_CHECK(old_holder != nullptr,
89+
"resize_ cannot grow a tensor without allocated storage");
90+
const size_t old_offset = dense_tensor->offset();
91+
const size_t copy_bytes = std::min(old_numel, new_numel_size) * itemsize;
92+
const phi::Place place = old_holder->place();
93+
const void* old_data =
94+
old_holder == nullptr
95+
? nullptr
96+
: reinterpret_cast<const uint8_t*>(old_holder->ptr()) + old_offset;
97+
98+
dense_tensor->Resize(dims);
99+
void* new_data = dense_tensor->mutable_data(place, dense_tensor->dtype());
100+
if (copy_bytes > 0 && old_data != nullptr && old_data != new_data) {
101+
phi::memory_utils::Copy(place, new_data, place, old_data, copy_bytes);
102+
}
33103
return *this;
34104
}
35105

test/cpp/compat/ATen_resize_test.cc

Lines changed: 92 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,17 @@
1919
#include <c10/core/ScalarType.h>
2020
#include <c10/core/TensorOptions.h>
2121

22+
#include <limits>
23+
#include <vector>
24+
2225
#include "ATen/ATen.h"
2326
#include "gtest/gtest.h"
2427
#include "torch/all.h"
2528

2629
// ======================== resize_ tests ========================
27-
// Note: Paddle's resize_ is implemented via reshape, which requires
28-
// total element count to remain unchanged.
30+
// Note: compat resize_ mutates the underlying DenseTensor directly so
31+
// shrink/grow round-trips preserve storage semantics without introducing new
32+
// memory_format hard errors in this split PR.
2933

3034
TEST(TensorResizeTest, ResizeBasic) {
3135
// Create a 2x3 tensor
@@ -109,6 +113,92 @@ TEST(TensorResizeTest, ResizePreservesData) {
109113
ASSERT_FLOAT_EQ(data[5], 5.0f);
110114
}
111115

116+
TEST(TensorResizeTest, ResizeShrinkDifferentNumel) {
  // Shrinking 2x3x4 (24 elements) to 4x5 (20) keeps the leading values.
  at::Tensor tensor = at::arange(24, at::kFloat).reshape({2, 3, 4});

  tensor.resize_({4, 5});

  ASSERT_EQ(tensor.sizes()[0], 4);
  ASSERT_EQ(tensor.sizes()[1], 5);

  const float* values = tensor.data_ptr<float>();
  for (int idx = 0; idx < 20; ++idx) {
    ASSERT_FLOAT_EQ(values[idx], static_cast<float>(idx));
  }
}
129+
130+
TEST(TensorResizeTest, ResizeGrowDifferentNumelPreservesPrefix) {
  // Growing 2x3 (6 elements) to 2x5 (10) must keep the original six
  // values at the front of the reallocated buffer.
  at::Tensor tensor = at::arange(6, at::kFloat).reshape({2, 3});

  tensor.resize_({2, 5});

  ASSERT_EQ(tensor.sizes()[0], 2);
  ASSERT_EQ(tensor.sizes()[1], 5);

  const float* values = tensor.data_ptr<float>();
  for (int idx = 0; idx < 6; ++idx) {
    ASSERT_FLOAT_EQ(values[idx], static_cast<float>(idx));
  }
}
143+
144+
TEST(TensorResizeTest, ResizeShrinkGrowRoundTripPreservesTail) {
  // Shrinking and then restoring the original shape must expose all 24
  // original values: the shrink only changes metadata, never the storage.
  at::Tensor tensor = at::arange(24, at::kFloat).reshape({2, 3, 4});

  tensor.resize_({4, 5});
  tensor.resize_({2, 3, 4});

  ASSERT_EQ(tensor.sizes()[0], 2);
  ASSERT_EQ(tensor.sizes()[1], 3);
  ASSERT_EQ(tensor.sizes()[2], 4);

  const float* values = tensor.data_ptr<float>();
  for (int idx = 0; idx < 24; ++idx) {
    ASSERT_FLOAT_EQ(values[idx], static_cast<float>(idx));
  }
}
159+
160+
TEST(TensorResizeTest, ResizeChannelsLastMemoryFormatDoesNotThrow) {
  // The memory_format argument is currently accepted but ignored in this
  // split PR; passing ChannelsLast must not raise.
  at::Tensor tensor = at::arange(24, at::kFloat).reshape({1, 2, 3, 4});
  const std::vector<int64_t> new_shape{1, 3, 2, 4};

  EXPECT_NO_THROW(
      { tensor.resize_(new_shape, at::MemoryFormat::ChannelsLast); });

  ASSERT_EQ(tensor.sizes()[0], 1);
  ASSERT_EQ(tensor.sizes()[1], 3);
  ASSERT_EQ(tensor.sizes()[2], 2);
  ASSERT_EQ(tensor.sizes()[3], 4);
}
172+
173+
TEST(TensorResizeTest, ResizeChannelsLast3dMemoryFormatDoesNotThrow) {
  // Same as the ChannelsLast case but for the 5-D variant: the format hint
  // is ignored for now and must not raise.
  at::Tensor tensor = at::arange(24, at::kFloat).reshape({1, 2, 2, 2, 3});
  const std::vector<int64_t> new_shape{1, 2, 2, 3, 2};

  EXPECT_NO_THROW(
      { tensor.resize_(new_shape, at::MemoryFormat::ChannelsLast3d); });

  ASSERT_EQ(tensor.sizes()[0], 1);
  ASSERT_EQ(tensor.sizes()[1], 2);
  ASSERT_EQ(tensor.sizes()[2], 2);
  ASSERT_EQ(tensor.sizes()[3], 3);
  ASSERT_EQ(tensor.sizes()[4], 2);
}
187+
188+
TEST(TensorResizeTest, ResizeRejectsNegativeDimension) {
  // A negative extent is never a valid resize target and must raise.
  at::Tensor tensor = at::arange(6, at::kFloat);
  const std::vector<int64_t> negative_shape{2, -1};

  EXPECT_THROW(tensor.resize_(negative_shape), std::exception);
}
194+
195+
TEST(TensorResizeTest, ResizeRejectsNumelOverflow) {
  // A shape whose element count overflows int64 must be rejected before
  // any allocation is attempted.
  at::Tensor tensor = at::arange(1, at::kFloat);
  const std::vector<int64_t> overflowing_shape{
      std::numeric_limits<int64_t>::max(), 2};

  EXPECT_THROW(tensor.resize_(overflowing_shape), std::exception);
}
201+
112202
TEST(TensorResizeTest, ResizeReturnReference) {
113203
// Create a tensor
114204
at::Tensor t = at::zeros({2, 3});

0 commit comments

Comments
 (0)