Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions paddle/phi/api/include/compat/ATen/ops/equal.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,12 @@
namespace at {

inline bool equal(const at::Tensor& self, const at::Tensor& other) {
PD_CHECK(self.defined(),
"Expected a proper Tensor but got None (or an undefined Tensor in "
"C++)");
PD_CHECK(other.defined(),
"Expected a proper Tensor but got None (or an undefined Tensor in "
"C++)");
PD_CHECK(self.device() == other.device(),
"Cannot compare two tensors on "
"different devices. Got: ",
Expand Down
28 changes: 25 additions & 3 deletions paddle/phi/api/include/compat/ATen/ops/select.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,35 @@
namespace at {

inline at::Tensor select(const at::Tensor& self, int64_t dim, int64_t index) {
// Normalize dim to positive value for error messages
int64_t orig_dim = dim;
if (dim < 0) {
dim += self.dim();
}
// Handle negative indexing
// Check dim is valid
if (dim < 0 || dim >= self.dim()) {
PD_CHECK(false,
"select(): index ",
orig_dim,
" out of range for tensor of size ",
self.sizes(),
" at dimension ",
orig_dim);
}
// Handle negative index
int64_t orig_index = index;
if (index < 0) {
int64_t dim_size = self.size(dim);
index = dim_size + index;
index = self.size(dim) + index;
}
// Check index is valid
if (index < 0 || index >= self.size(dim)) {
PD_CHECK(false,
"select(): index ",
orig_index,
" out of range for tensor of size ",
self.sizes(),
" at dimension ",
orig_dim < 0 ? orig_dim + self.dim() : orig_dim);
}
Comment on lines +28 to 51
Copy link

Copilot AI Apr 1, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

New error handling was added for invalid dim / index values, but the existing select tests (e.g., test/cpp/compat/ATen_select_test.cc) don't cover these out-of-range branches. Add test cases that assert an exception is thrown for (1) dim out of range (including negative beyond -self.dim()), and (2) index out of range (including negative beyond -size(dim)).

Copilot uses AI. Check for mistakes.

return Tensor(
Expand Down
15 changes: 15 additions & 0 deletions paddle/phi/api/include/compat/ATen/ops/std.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,21 @@ inline Tensor std_impl(const Tensor& self,
const std::vector<int64_t>& dims_vec,
double correction_value,
bool keepdim) {
// Validate dimensions before processing
int64_t ndim = self.dim();
for (int64_t d : dims_vec) {
int64_t dim_idx = d < 0 ? d + ndim : d;
if (dim_idx < 0 || dim_idx >= ndim) {
PD_CHECK(false,
"Dimension out of range (expected to be in range of [",
-ndim,
", ",
ndim - 1,
"], but got ",
d,
")");
}
}
phi::IntArray dims_int_array(dims_vec);
paddle::Tensor tensor = self._PD_GetInner();

Expand Down
48 changes: 35 additions & 13 deletions paddle/phi/api/include/compat/c10/core/ScalarType.h
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,23 @@ AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(SPECIALIZE_CppTypeToScalarType)
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_CONSTANT)
#undef DEFINE_CONSTANT

constexpr ScalarType kComplexHalf = ScalarType::ComplexHalf;
constexpr ScalarType kQInt8 = ScalarType::QInt8;
constexpr ScalarType kQUInt8 = ScalarType::QUInt8;
constexpr ScalarType kQInt32 = ScalarType::QInt32;
constexpr ScalarType kQUInt4x2 = ScalarType::QUInt4x2;
constexpr ScalarType kQUInt2x4 = ScalarType::QUInt2x4;
constexpr ScalarType kBits1x8 = ScalarType::Bits1x8;
constexpr ScalarType kBits2x4 = ScalarType::Bits2x4;
constexpr ScalarType kBits4x2 = ScalarType::Bits4x2;
constexpr ScalarType kBits8 = ScalarType::Bits8;
constexpr ScalarType kBits16 = ScalarType::Bits16;
constexpr ScalarType kFloat8_e5m2fnuz = ScalarType::Float8_e5m2fnuz;
constexpr ScalarType kFloat8_e4m3fnuz = ScalarType::Float8_e4m3fnuz;
constexpr ScalarType kFloat8_e8m0fnu = ScalarType::Float8_e8m0fnu;
constexpr ScalarType kFloat4_e2m1fn_x2 = ScalarType::Float4_e2m1fn_x2;
constexpr ScalarType kUndefined = ScalarType::Undefined;

#define AT_FORALL_SCALAR_TYPES_AND(SCALARTYPE, _) \
_(uint8_t, Byte) \
_(int8_t, Char) \
Expand Down Expand Up @@ -281,6 +298,8 @@ inline const char* toString(ScalarType t) {

switch (t) {
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_CASE)
case ScalarType::ComplexHalf:
return "ComplexHalf";
case ScalarType::QInt8:
return "QInt8";
case ScalarType::QUInt8:
Expand All @@ -291,8 +310,6 @@ inline const char* toString(ScalarType t) {
return "QUInt4x2";
case ScalarType::QUInt2x4:
return "QUInt2x4";
case ScalarType::ComplexHalf:
return "ComplexHalf";
case ScalarType::Bits1x8:
return "Bits1x8";
case ScalarType::Bits2x4:
Expand Down Expand Up @@ -326,6 +343,8 @@ inline size_t elementSize(ScalarType t) {

switch (t) {
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(CASE_ELEMENTSIZE_CASE)
case ScalarType::ComplexHalf:
return sizeof(at::Half) * 2;
case ScalarType::QInt8:
case ScalarType::QUInt8:
case ScalarType::QUInt4x2:
Expand All @@ -334,10 +353,15 @@ inline size_t elementSize(ScalarType t) {
case ScalarType::Bits2x4:
case ScalarType::Bits4x2:
case ScalarType::Bits8:
case ScalarType::Float8_e5m2fnuz:
case ScalarType::Float8_e4m3fnuz:
case ScalarType::Float8_e8m0fnu:
case ScalarType::Float4_e2m1fn_x2:
return 1;
case ScalarType::QInt32:
case ScalarType::Bits16:
return 4;
case ScalarType::Bits16:
return 2;
default:
TORCH_CHECK(false, "Unknown ScalarType");
}
Expand All @@ -354,15 +378,14 @@ inline bool isIntegralType(ScalarType t, bool includeBool) {
}

inline bool isFloat8Type(ScalarType t) {
return t == ScalarType::Float8_e5m2 || t == ScalarType::Float8_e4m3fn;
// || t == ScalarType::Float8_e5m2fnuz
// || t == ScalarType::Float8_e4m3fnuz
// || t == ScalarType::Float8_e8m0fnu
return t == ScalarType::Float8_e5m2 || t == ScalarType::Float8_e4m3fn ||
t == ScalarType::Float8_e5m2fnuz || t == ScalarType::Float8_e4m3fnuz ||
t == ScalarType::Float8_e8m0fnu;
}

inline bool isReducedFloatingType(ScalarType t) {
return t == ScalarType::Half || t == ScalarType::BFloat16 || isFloat8Type(t);
//|| t == ScalarType::Float4_e2m1fn_x2
return t == ScalarType::Half || t == ScalarType::BFloat16 ||
isFloat8Type(t) || t == ScalarType::Float4_e2m1fn_x2;
}

inline bool isFloatingType(ScalarType t) {
Expand All @@ -371,9 +394,8 @@ inline bool isFloatingType(ScalarType t) {
}

inline bool isComplexType(ScalarType t) {
return (
/* t == ScalarType::ComplexHalf || */ t == ScalarType::ComplexFloat ||
t == ScalarType::ComplexDouble);
return (t == ScalarType::ComplexHalf || t == ScalarType::ComplexFloat ||
t == ScalarType::ComplexDouble);
}

inline bool isSignedType(ScalarType t) {
Expand Down Expand Up @@ -408,9 +430,9 @@ inline bool isSignedType(ScalarType t) {
CASE_ISSIGNED(Float8_e4m3fn);

// Complex types (treated as signed)
case ScalarType::ComplexHalf:
case ScalarType::ComplexFloat:
case ScalarType::ComplexDouble:
case ScalarType::ComplexHalf:
return true;

// Signed quantized types (explicitly return true)
Expand Down
8 changes: 5 additions & 3 deletions paddle/phi/api/include/compat/c10/core/Stream.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,11 @@ void* Stream::native_handle() const {
return reinterpret_cast<void*>(static_cast<intptr_t>(id_));
}
#endif
PADDLE_THROW(::common::errors::Unimplemented(
"c10::Stream::native_handle() is not supported for device type %d",
static_cast<int>(device_type())));
// Match PyTorch error message format for unsupported device types
PD_CHECK(false,
"native_handle() is not supported for this device type (",
Comment on lines +48 to +49
Copy link

Copilot AI Apr 1, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

PD_CHECK(false, ...) will throw a PD_Exception that appends an additional "Expected false, but it's not satisfied." context. If the goal is to match PyTorch's error message format exactly (as the comment states), this will not produce an exact match. Prefer throwing directly (e.g., TORCH_CHECK(false, ...) / PD_THROW(...)) or pass the real condition into PD_CHECK(...) instead of false so the extra context remains meaningful.

Suggested change
PD_CHECK(false,
"native_handle() is not supported for this device type (",
PD_THROW("native_handle() is not supported for this device type (",

Copilot uses AI. Check for mistakes.
static_cast<int>(device_type()),
")");
}

bool Stream::query() const {
Expand Down
5 changes: 2 additions & 3 deletions paddle/phi/api/include/compat/c10/core/Stream.h
Original file line number Diff line number Diff line change
Expand Up @@ -90,9 +90,8 @@ class Stream final {
};

inline std::ostream& operator<<(std::ostream& os, const Stream& s) {
os << "Stream(device_type=" << static_cast<int>(s.device_type())
<< ", device_index=" << static_cast<int>(s.device_index())
<< ", id=" << s.id() << ")";
// Format: "stream {id} on device {device_type}:{device_index}"
os << "stream " << s.id() << " on device " << s.device();
return os;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,5 @@ void synchronize(int64_t device_index = -1);

} // namespace torch::cuda
namespace at::cuda {
using torch::cuda::device_count;
using torch::cuda::is_available;
using torch::cuda::synchronize;
} // namespace at::cuda
2 changes: 1 addition & 1 deletion test/cpp/compat/ATen_TensorAccessor_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ TEST(TensorAccessorTest, PackedAccessorWithIntType) {

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
TEST(TensorAccessorTest, PackedAccessorCUDA) {
if (at::cuda::is_available()) {
if (torch::cuda::is_available()) {
// Create CUDA tensor
at::Tensor tensor =
at::arange(12, at::TensorOptions().dtype(at::kFloat).device(at::kCUDA))
Expand Down
109 changes: 47 additions & 62 deletions test/cpp/compat/c10_ScalarType_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
#include <c10/cuda/CUDAFunctions.h>
#include <c10/cuda/CUDAGuard.h>
#endif
#include <sstream>
#include "ATen/ATen.h"
#include "gtest/gtest.h"
#include "paddle/phi/common/float16.h"
Expand Down Expand Up @@ -92,8 +91,32 @@ TEST(TensorBaseTest, TypeCheckingAPIs) {
ASSERT_FALSE(bool_tensor.is_signed());
}

TEST(ScalarTypeCompatTest, ScalarTypeUtilityBranches) {
TEST(ScalarTypeTest, RestoredCompatScalarTypesKeepSourceLevelSemantics) {
EXPECT_EQ(static_cast<int>(c10::ScalarType::ComplexHalf), 8);
EXPECT_EQ(static_cast<int>(c10::ScalarType::QInt8), 12);
EXPECT_EQ(static_cast<int>(c10::ScalarType::Bits16), 22);
EXPECT_EQ(static_cast<int>(c10::ScalarType::Float8_e5m2fnuz), 25);
EXPECT_EQ(static_cast<int>(c10::ScalarType::Float4_e2m1fn_x2), 45);
EXPECT_EQ(c10::NumScalarTypes, 47);

EXPECT_EQ(c10::kComplexHalf, c10::ScalarType::ComplexHalf);
EXPECT_EQ(c10::kQInt8, c10::ScalarType::QInt8);
EXPECT_EQ(c10::kBits16, c10::ScalarType::Bits16);
EXPECT_EQ(c10::kFloat8_e4m3fnuz, c10::ScalarType::Float8_e4m3fnuz);
EXPECT_EQ(c10::kFloat8_e8m0fnu, c10::ScalarType::Float8_e8m0fnu);
EXPECT_EQ(c10::kFloat4_e2m1fn_x2, c10::ScalarType::Float4_e2m1fn_x2);
EXPECT_EQ(c10::kUndefined, c10::ScalarType::Undefined);

EXPECT_STREQ(c10::toString(c10::ScalarType::ComplexHalf), "ComplexHalf");
EXPECT_STREQ(c10::toString(c10::ScalarType::QInt8), "QInt8");
EXPECT_STREQ(c10::toString(c10::ScalarType::QUInt8), "QUInt8");
EXPECT_STREQ(c10::toString(c10::ScalarType::QInt32), "QInt32");
EXPECT_STREQ(c10::toString(c10::ScalarType::QUInt4x2), "QUInt4x2");
EXPECT_STREQ(c10::toString(c10::ScalarType::QUInt2x4), "QUInt2x4");
EXPECT_STREQ(c10::toString(c10::ScalarType::Bits1x8), "Bits1x8");
EXPECT_STREQ(c10::toString(c10::ScalarType::Bits2x4), "Bits2x4");
EXPECT_STREQ(c10::toString(c10::ScalarType::Bits4x2), "Bits4x2");
EXPECT_STREQ(c10::toString(c10::ScalarType::Bits8), "Bits8");
EXPECT_STREQ(c10::toString(c10::ScalarType::Bits16), "Bits16");
EXPECT_STREQ(c10::toString(c10::ScalarType::Float8_e5m2fnuz),
"Float8_e5m2fnuz");
Expand All @@ -104,65 +127,27 @@ TEST(ScalarTypeCompatTest, ScalarTypeUtilityBranches) {
EXPECT_STREQ(c10::toString(c10::ScalarType::Float4_e2m1fn_x2),
"Float4_e2m1fn_x2");
EXPECT_STREQ(c10::toString(c10::ScalarType::Undefined), "Undefined");
EXPECT_STREQ(c10::toString(static_cast<c10::ScalarType>(-1)),
"UNKNOWN_SCALAR");

EXPECT_EQ(c10::elementSize(c10::ScalarType::QInt8), static_cast<size_t>(1));
EXPECT_EQ(c10::elementSize(c10::ScalarType::QUInt4x2),
static_cast<size_t>(1));
EXPECT_EQ(c10::elementSize(c10::ScalarType::QInt32), static_cast<size_t>(4));
EXPECT_EQ(c10::elementSize(c10::ScalarType::Bits16), static_cast<size_t>(4));
EXPECT_THROW(c10::elementSize(c10::ScalarType::Undefined), ::std::exception);

EXPECT_TRUE(c10::isIntegralType(c10::ScalarType::Bool, true));
EXPECT_FALSE(c10::isIntegralType(c10::ScalarType::Bool, false));
EXPECT_TRUE(c10::isFloat8Type(c10::ScalarType::Float8_e5m2));
EXPECT_FALSE(c10::isFloat8Type(c10::ScalarType::Float8_e4m3fnuz));
EXPECT_TRUE(c10::isReducedFloatingType(c10::ScalarType::BFloat16));
EXPECT_TRUE(c10::isFloatingType(c10::ScalarType::Float));
EXPECT_FALSE(c10::isComplexType(c10::ScalarType::ComplexHalf));

EXPECT_TRUE(c10::isSignedType(c10::ScalarType::Int1));
EXPECT_FALSE(c10::isSignedType(c10::ScalarType::UInt3));
EXPECT_FALSE(c10::isSignedType(c10::ScalarType::QUInt8));
EXPECT_TRUE(c10::isSignedType(c10::ScalarType::Float8_e5m2fnuz));
EXPECT_THROW(c10::isSignedType(c10::ScalarType::Undefined), ::std::exception);

std::ostringstream oss;
oss << c10::ScalarType::UInt7;
EXPECT_EQ(oss.str(), "UInt7");
}

TEST(ScalarTypeCompatTest, AdditionalEnumAndPredicateBranches) {
EXPECT_STREQ(c10::toString(c10::ScalarType::QInt8), "QInt8");
EXPECT_STREQ(c10::toString(c10::ScalarType::QUInt8), "QUInt8");
EXPECT_STREQ(c10::toString(c10::ScalarType::QInt32), "QInt32");
EXPECT_STREQ(c10::toString(c10::ScalarType::QUInt4x2), "QUInt4x2");
EXPECT_STREQ(c10::toString(c10::ScalarType::QUInt2x4), "QUInt2x4");
EXPECT_STREQ(c10::toString(c10::ScalarType::ComplexHalf), "ComplexHalf");
EXPECT_STREQ(c10::toString(c10::ScalarType::Bits2x4), "Bits2x4");
EXPECT_STREQ(c10::toString(c10::ScalarType::Bits4x2), "Bits4x2");
EXPECT_STREQ(c10::toString(c10::ScalarType::Bits8), "Bits8");

EXPECT_EQ(c10::elementSize(c10::ScalarType::QUInt8), static_cast<size_t>(1));
EXPECT_EQ(c10::elementSize(c10::ScalarType::QUInt2x4),
static_cast<size_t>(1));
EXPECT_EQ(c10::elementSize(c10::ScalarType::Bits2x4), static_cast<size_t>(1));
EXPECT_EQ(c10::elementSize(c10::ScalarType::Bits4x2), static_cast<size_t>(1));
EXPECT_EQ(c10::elementSize(c10::ScalarType::Bits8), static_cast<size_t>(1));

EXPECT_TRUE(c10::isIntegralType(c10::ScalarType::UInt64, false));
EXPECT_FALSE(c10::isIntegralType(c10::ScalarType::Float, true));
EXPECT_TRUE(c10::isFloat8Type(c10::ScalarType::Float8_e4m3fn));
EXPECT_TRUE(c10::isReducedFloatingType(c10::ScalarType::Half));
EXPECT_FALSE(c10::isReducedFloatingType(c10::ScalarType::Float));
EXPECT_TRUE(c10::isFloatingType(c10::ScalarType::Half));
EXPECT_TRUE(c10::isComplexType(c10::ScalarType::ComplexFloat));

EXPECT_TRUE(c10::isSignedType(c10::ScalarType::QInt8));
EXPECT_TRUE(c10::isSignedType(c10::ScalarType::ComplexHalf));
EXPECT_FALSE(c10::isSignedType(c10::ScalarType::Byte));
EXPECT_FALSE(c10::isSignedType(c10::ScalarType::Bool));
EXPECT_THROW(c10::isSignedType(c10::ScalarType::NumOptions),
::std::exception);
EXPECT_EQ(c10::elementSize(c10::ScalarType::ComplexHalf),
sizeof(at::Half) * 2);
EXPECT_EQ(c10::elementSize(c10::ScalarType::QInt8), 1U);
EXPECT_EQ(c10::elementSize(c10::ScalarType::QUInt8), 1U);
EXPECT_EQ(c10::elementSize(c10::ScalarType::QInt32), 4U);
EXPECT_EQ(c10::elementSize(c10::ScalarType::QUInt4x2), 1U);
EXPECT_EQ(c10::elementSize(c10::ScalarType::QUInt2x4), 1U);
EXPECT_EQ(c10::elementSize(c10::ScalarType::Bits1x8), 1U);
EXPECT_EQ(c10::elementSize(c10::ScalarType::Bits2x4), 1U);
EXPECT_EQ(c10::elementSize(c10::ScalarType::Bits4x2), 1U);
EXPECT_EQ(c10::elementSize(c10::ScalarType::Bits8), 1U);
EXPECT_EQ(c10::elementSize(c10::ScalarType::Bits16), 2U);
EXPECT_EQ(c10::elementSize(c10::ScalarType::Float8_e5m2fnuz), 1U);
EXPECT_EQ(c10::elementSize(c10::ScalarType::Float8_e4m3fnuz), 1U);
EXPECT_EQ(c10::elementSize(c10::ScalarType::Float8_e8m0fnu), 1U);
EXPECT_EQ(c10::elementSize(c10::ScalarType::Float4_e2m1fn_x2), 1U);

EXPECT_TRUE(c10::isComplexType(c10::ScalarType::ComplexHalf));
EXPECT_TRUE(c10::isFloat8Type(c10::ScalarType::Float8_e5m2fnuz));
EXPECT_TRUE(c10::isFloat8Type(c10::ScalarType::Float8_e4m3fnuz));
EXPECT_TRUE(c10::isFloat8Type(c10::ScalarType::Float8_e8m0fnu));
EXPECT_TRUE(c10::isReducedFloatingType(c10::ScalarType::Float4_e2m1fn_x2));
}
9 changes: 5 additions & 4 deletions test/cpp/compat/c10_storage_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,9 @@
#include "paddle/phi/backends/gpu/gpu_info.h"

// Forward-declare getCUDADeviceAllocator to avoid include-order conflicts
// between ATen/cuda/CUDAContextLight.h (defines at::cuda::is_available inline)
// and torch/cuda.h (adds `using torch::cuda::is_available` to at::cuda).
// between ATen/cuda/CUDAContextLight.h (defines torch::cuda::is_available
// inline) and torch/cuda.h (adds `using torch::cuda::is_available` to
// at::cuda).
namespace at::cuda {
c10::Allocator* getCUDADeviceAllocator();
} // namespace at::cuda
Expand Down Expand Up @@ -242,7 +243,7 @@ TEST(StorageTest, DeviceAndDeviceTypeAPIs) {
ASSERT_EQ(place.GetType(), phi::AllocationType::CPU);

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
if (at::cuda::is_available()) {
if (torch::cuda::is_available()) {
at::TensorBase cuda_tensor = at::ones(
{2, 3}, c10::TensorOptions().dtype(at::kFloat).device(at::kCUDA));
const c10::Storage& cuda_storage = cuda_tensor.storage();
Expand Down Expand Up @@ -1041,7 +1042,7 @@ TEST(StorageTest, ReferenceSemanticsSetNbytesVisibleThroughCopy) {
TEST(StorageTest, CUDAAllocatorZeroBytePreservesDevice) {
// getCUDADeviceAllocator()->allocate(0) must return a DataPtr whose device
// is the current CUDA device, not a default-constructed CPU DataPtr.
if (!at::cuda::is_available()) {
if (!torch::cuda::is_available()) {
return; // No CUDA device, skip
}

Expand Down
2 changes: 1 addition & 1 deletion test/cpp/compat/compat_basic_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ TEST(compat_basic_test, BasicCase) {
TEST(TestDevice, DeviceAPIsOnCUDA) {
// Test device related APIs on CUDA if available
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
if (at::cuda::is_available()) {
if (torch::cuda::is_available()) {
at::TensorBase cuda_tensor = at::ones(
{2, 3}, c10::TensorOptions().dtype(at::kFloat).device(at::kCUDA));

Expand Down
Loading