Merged · Changes from 2 commits
26 changes: 22 additions & 4 deletions paddle/phi/api/include/compat/ATen/ops/resize.h
@@ -23,13 +23,31 @@

namespace at {

// resize_ - in-place resize using reshape
// resize_ - use reshape for same-numel cases and set_ for storage-changing
// cases so repeated resize_ calls stay stable.
inline const at::Tensor& Tensor::resize_(
at::IntArrayRef size,
::std::optional<at::MemoryFormat> memory_format) const {
auto result =
paddle::experimental::reshape(tensor_, size._PD_ToPaddleIntArray());
const_cast<Tensor*>(this)->tensor_ = result;
if (memory_format.has_value()) {
TORCH_CHECK(*memory_format == at::MemoryFormat::Contiguous,
Member:

Hmm? This isn't a breaking change compared to before, is it? Paddle itself didn't support this before either, and would have produced an even less friendly error, so raising an error here is reasonable.
"resize_ only supports contiguous memory format, but got ",
static_cast<int>(*memory_format));
Copilot AI Apr 1, 2026:

The memory_format rejection message reports static_cast<int>(*memory_format), which is hard to interpret and inconsistent with other compat ops that report unsupported MemoryFormat values more directly. Prefer emitting a clearer value (e.g., the enum name) and/or reusing the same phrasing used in other ATen compat ops (like empty/empty_like) so downstream users get consistent diagnostics.

Suggested change:
-              static_cast<int>(*memory_format));
+              *memory_format);
}

std::vector<int64_t> dims(size.begin(), size.end());
int64_t new_numel = 1;
for (auto dim : dims) {
new_numel *= dim;
}
Copilot AI Apr 1, 2026:

new_numel is computed by multiplying dims without any checks for negative dimensions or int64 overflow. If a caller passes an invalid/very large size, new_numel can wrap and the code may take the reshape path or pass sizes into set_ in a surprising way. Consider validating dim >= 0 and using overflow-safe multiplication (or an existing checked helper) before comparing against tensor_.numel().

if (tensor_.numel() == new_numel) {
const_cast<Tensor*>(this)->tensor_ =
paddle::experimental::reshape(tensor_, phi::IntArray(dims));
return *this;
}

auto source = tensor_.copy_to(tensor_.place(), /*blocking=*/true);
paddle::experimental::set_(const_cast<Tensor*>(this)->tensor_, source, dims);
Member:

@ShigureNyako Please give a detailed suggestion, proposed via a code suggestion. What exactly is the compatibility risk here? Would it add a lot of complexity? If it stays within a controllable scope, we can change it the way you described.

return *this;
}

32 changes: 30 additions & 2 deletions test/cpp/compat/ATen_resize_test.cc
@@ -24,8 +24,8 @@
#include "torch/all.h"

// ======================== resize_ tests ========================
// Note: Paddle's resize_ is implemented via reshape, which requires
// total element count to remain unchanged.
// Note: compat resize_ uses reshape when numel is unchanged, and falls back to
// set_ for storage-changing cases so repeated resize_ calls remain stable.

TEST(TensorResizeTest, ResizeBasic) {
// Create a 2x3 tensor
@@ -109,6 +109,34 @@ TEST(TensorResizeTest, ResizePreservesData) {
ASSERT_FLOAT_EQ(data[5], 5.0f);
}

TEST(TensorResizeTest, ResizeShrinkDifferentNumel) {
at::Tensor t = at::arange(24, at::kFloat).reshape({2, 3, 4});

t.resize_({4, 5});

Comment on lines +116 to +120

Copilot AI Apr 1, 2026:

New behavior was added for the memory_format argument (only contiguous is accepted), but the test suite here doesn't cover that path. Adding a small gtest that passes a non-contiguous memory format and asserts the expected failure would lock in the intended compatibility behavior.
ASSERT_EQ(t.sizes()[0], 4);
ASSERT_EQ(t.sizes()[1], 5);

float* data = t.data_ptr<float>();
for (int i = 0; i < 20; ++i) {
ASSERT_FLOAT_EQ(data[i], static_cast<float>(i));
}
}

TEST(TensorResizeTest, ResizeGrowDifferentNumelPreservesPrefix) {
at::Tensor t = at::arange(6, at::kFloat).reshape({2, 3});

t.resize_({2, 5});

ASSERT_EQ(t.sizes()[0], 2);
ASSERT_EQ(t.sizes()[1], 5);

float* data = t.data_ptr<float>();
for (int i = 0; i < 6; ++i) {
ASSERT_FLOAT_EQ(data[i], static_cast<float>(i));
}
}

TEST(TensorResizeTest, ResizeReturnReference) {
// Create a tensor
at::Tensor t = at::zeros({2, 3});