[Cpp API Compatibility] Align resize api #78554
Changes from 2 commits
```diff
@@ -23,13 +23,31 @@
 namespace at {

-// resize_ - in-place resize using reshape
+// resize_ - use reshape for same-numel cases and set_ for storage-changing
+// cases so repeated resize_ calls stay stable.
 inline const at::Tensor& Tensor::resize_(
     at::IntArrayRef size,
     ::std::optional<at::MemoryFormat> memory_format) const {
   auto result =
       paddle::experimental::reshape(tensor_, size._PD_ToPaddleIntArray());
   const_cast<Tensor*>(this)->tensor_ = result;
   if (memory_format.has_value()) {
     TORCH_CHECK(*memory_format == at::MemoryFormat::Contiguous,
                 "resize_ only supports contiguous memory format, but got ",
                 static_cast<int>(*memory_format));
```
|
Suggested change:

```diff
-                static_cast<int>(*memory_format));
+                *memory_format);
```
Copilot (AI) commented on Apr 1, 2026:
new_numel is computed by multiplying dims without any checks for negative dimensions or int64 overflow. If a caller passes an invalid/very large size, new_numel can wrap and the code may take the reshape path or pass sizes into set_ in a surprising way. Consider validating dim >= 0 and using overflow-safe multiplication (or an existing checked helper) before comparing against tensor_.numel().
@ShigureNyako Please give a detailed suggestion, raised as a code suggestion. What exactly is the compatibility risk here? Would it add a lot of complexity? If it stays within a manageable scope, we can change it as you propose.
```diff
@@ -24,8 +24,8 @@
 #include "torch/all.h"

 // ======================== resize_ tests ========================
-// Note: Paddle's resize_ is implemented via reshape, which requires
-// total element count to remain unchanged.
+// Note: compat resize_ uses reshape when numel is unchanged, and falls back to
+// set_ for storage-changing cases so repeated resize_ calls remain stable.

 TEST(TensorResizeTest, ResizeBasic) {
   // Create a 2x3 tensor
```
```diff
@@ -109,6 +109,34 @@ TEST(TensorResizeTest, ResizePreservesData) {
   ASSERT_FLOAT_EQ(data[5], 5.0f);
 }

+TEST(TensorResizeTest, ResizeShrinkDifferentNumel) {
+  at::Tensor t = at::arange(24, at::kFloat).reshape({2, 3, 4});
+
+  t.resize_({4, 5});
+
+  ASSERT_EQ(t.sizes()[0], 4);
+  ASSERT_EQ(t.sizes()[1], 5);
+
+  float* data = t.data_ptr<float>();
+  for (int i = 0; i < 20; ++i) {
+    ASSERT_FLOAT_EQ(data[i], static_cast<float>(i));
+  }
+}
+
+TEST(TensorResizeTest, ResizeGrowDifferentNumelPreservesPrefix) {
+  at::Tensor t = at::arange(6, at::kFloat).reshape({2, 3});
+
+  t.resize_({2, 5});
+
+  ASSERT_EQ(t.sizes()[0], 2);
+  ASSERT_EQ(t.sizes()[1], 5);
+
+  float* data = t.data_ptr<float>();
+  for (int i = 0; i < 6; ++i) {
+    ASSERT_FLOAT_EQ(data[i], static_cast<float>(i));
+  }
+}
+
 TEST(TensorResizeTest, ResizeReturnReference) {
   // Create a tensor
   at::Tensor t = at::zeros({2, 3});
```

Comment on lines +116 to +120:
Hmm? This isn't breaking compared to before, right? Previously Paddle itself didn't support this either, and the error would have been even less friendly, so raising an error here is reasonable.