Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 34 additions & 14 deletions backends/aoti/slim/core/slim_tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -433,13 +433,19 @@ class SlimTensor {
/**
* Copy data from another tensor to this tensor.
*
* Both tensors must have the same numel and dtype.
* Currently only supports CPU-to-CPU copy (contiguous tensors only).
* Both tensors must have the same numel, sizes, and dtype.
*
* @param other The source tensor to copy from
* @return Reference to this tensor
*/
SlimTensor& copy_(const SlimTensor& other) {
ET_CHECK_MSG(
this->dim() == other.dim(),
"copy_: dim of tensors must match (%zu vs %zu)",
this->dim(),
Comment on lines 441 to +445
other.dim());
ET_CHECK_MSG(
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This check also enforces that the two tensors have the same shape -- can we remove it?

this->sizes() == other.sizes(), "copy_: sizes of tensors must match");
ET_CHECK_MSG(
this->numel() == other.numel(), "copy_: numel of tensors must match");
ET_CHECK_MSG(this->dtype() == other.dtype(), "copy_: dtype must match");
Comment on lines +447 to 451
Expand All @@ -463,29 +469,43 @@ class SlimTensor {

std::vector<int64_t> counter(this->dim(), 0);
for (size_t i = 0; i < this->numel(); i++) {
// Compute src offset in elements
int64_t src_offset = 0;
for (size_t d = 0; d < other.dim(); d++) {
src_offset += counter[d] * other.stride(d);
}

// Compute dst offset in elements
int64_t dst_offset = 0;
for (size_t d = 0; d < this->dim(); d++) {
dst_offset += counter[d] * this->stride(d);
int64_t src_term = 0;
int64_t dst_term = 0;
// src_offset = src_offset + counter[d] * other.stride(d)
// dst_offset = dst_offset + counter[d] * this->stride(d)
ET_CHECK_MSG(
!::c10::mul_overflows(counter[d], other.stride(d), &src_term) &&
!::c10::add_overflows(src_offset, src_term, &src_offset) &&
!::c10::mul_overflows(counter[d], this->stride(d), &dst_term) &&
!::c10::add_overflows(dst_offset, dst_term, &dst_offset),
"copy_: offset computation overflow");
}
size_t src_byte_offset = 0;
size_t dst_byte_offset = 0;
// src_byte_offset = src_offset * elem_size
// dst_byte_offset = dst_offset * elem_size
ET_CHECK_MSG(
src_offset >= 0 && dst_offset >= 0 &&
!::c10::mul_overflows(
static_cast<size_t>(src_offset),
elem_size,
&src_byte_offset) &&
!::c10::mul_overflows(
static_cast<size_t>(dst_offset), elem_size, &dst_byte_offset),
"copy_: byte offset overflow");

// Copy elem_size bytes from src to dst
if (this->device().is_cpu() && other.device().is_cpu()) {
std::memcpy(
dst_data + dst_offset * elem_size,
src_data + src_offset * elem_size,
elem_size);
dst_data + dst_byte_offset, src_data + src_byte_offset, elem_size);
} else if (this->device().is_cuda() || other.device().is_cuda()) {
Comment on lines 500 to 504
#if defined(CUDA_AVAILABLE)
DeviceTraits<c10::DeviceType::CUDA>::memcpy(
dst_data + dst_offset * elem_size,
src_data + src_offset * elem_size,
dst_data + dst_byte_offset,
src_data + src_byte_offset,
elem_size,
device(), // dst device
other.device() // src device
Expand Down
Loading