13 changes: 6 additions & 7 deletions dali/python/backend_impl.cc
@@ -181,19 +181,18 @@ void CheckContiguousTensor(const TStrides &strides, int num_strides,
const TShape &shape, int num_extents, size_t element_size) {
DALI_ENFORCE(num_strides == num_extents,
"There should be exactly as many strides as there are extents in array shape.");
for (int i = 0; i < num_extents; i++)
if (shape[i] == 0)
return; // The volume is 0, the strides will never be used to compute an actual address

int64_t stride_from_shape = element_size;
int64_t stride_from_shape_collapsed = 1;
int64_t last_non_one_dim = 1;
Contributor

P2 Dead variable left over from refactor

last_non_one_dim is initialised but never read after the simplification; it was part of the old stride_from_shape_collapsed logic that was removed. Most compilers will warn about it.

Suggested change
int64_t last_non_one_dim = 1;
int64_t stride_from_shape = element_size;

for (int i = num_strides - 1; i >= 0; i--) {
DALI_ENFORCE(strides[i] == stride_from_shape || strides[i] == stride_from_shape_collapsed,
DALI_ENFORCE(shape[i] == 1 || // ignore unit extents - the stride won't be used anyway
strides[i] == stride_from_shape,
make_string("Strided data not supported. Dimension ", i, " has stride ", strides[i],
" whereas densely packed data of this shape would have a stride ", stride_from_shape));
stride_from_shape *= shape[i];
// for shapes [1, 1, 5] leading dimensions may not contribute to stride
if (shape[i] != 1) {
stride_from_shape_collapsed *= last_non_one_dim;
last_non_one_dim = shape[i];
}
}
}
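
To illustrate the change under discussion, here is a minimal Python sketch of the relaxed contiguity check (check_contiguous is a hypothetical helper for illustration, not DALI API; the real implementation is the C++ above). Unit extents are skipped because their strides never contribute to an address, and a zero-volume shape passes trivially:

```python
def check_contiguous(shape, strides, element_size):
    """Return True if `strides` (in bytes) describe densely packed,
    row-major data of `shape`, ignoring the strides of unit extents."""
    assert len(shape) == len(strides)
    if 0 in shape:
        return True  # volume is 0: the strides are never dereferenced
    expected = element_size
    for extent, stride in zip(reversed(shape), reversed(strides)):
        # an extent-1 dimension contributes nothing to any address,
        # so whatever stride an exporter reports for it is irrelevant
        if extent != 1 and stride != expected:
            return False  # strided (non-dense) layout
        expected *= extent
    return True

# shape (1, 2, 2) of float32: dense strides are (16, 8, 4), but an
# exporter may report the unit dim's stride as 4 - both describe the
# same memory, and both are now accepted
print(check_contiguous((1, 2, 2), (16, 8, 4), 4))  # True
print(check_contiguous((1, 2, 2), (4, 8, 4), 4))   # True
print(check_contiguous((2, 2), (4, 4), 4))         # False
```

The old check instead compared against a second "collapsed" stride sequence, which is why stride_from_shape_collapsed and last_non_one_dim could be deleted.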

14 changes: 14 additions & 0 deletions dali/test/python/experimental_mode/test_interop.py
@@ -89,6 +89,20 @@ def test_torch_nocopy(device: str):
np.testing.assert_array_equal(ndd_tensor.cpu(), [42, 2, 3])


@attr("pytorch")
@params(("cpu",), ("gpu",))
def test_from_torch(device: str):
import torch

if version.parse(torch.__version__) < version.parse("2.6.0"):
raise SkipTest("Requires PyTorch >= 2.6.0")
Comment on lines +97 to +98
Contributor

P2 Test guards against the version that triggers the bug

The PR description states the problem occurs in older PyTorch versions (strides of 1 for unit dimensions in DLPack). This test skips on < 2.6.0, meaning it never runs on the versions that actually exposed the bug. Consider adding a separate test path (or removing the version guard) to confirm the fix works on the affected older versions, or at minimum document why this gate is needed.

Contributor Author

@greptileai: The test fails with PyTorch 2.7.1, so the check against 2.6.0 doesn't defeat it.

Contributor

You're right: if the stride issue still occurs in PyTorch 2.7.1, then the >= 2.6.0 guard makes sense. The test will run on 2.7.1+ and verify the fix handles those cases. I misread the situation as the bug only affecting older versions.



inpt = torch.arange(1, 5).reshape((1, 2, 2))
ndd_tensor = ndd.as_tensor(inpt)

np.testing.assert_array_equal(ndd_tensor.cpu(), [[[1, 2], [3, 4]]])
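
For context on what the test exercises: a sketch of how dense row-major strides are computed for the (1, 2, 2) tensor above (dense_strides is an illustrative helper, not DALI or PyTorch API). A DLPack exporter may report a different stride for the leading unit dimension even though the memory layout is identical, which is exactly what the relaxed check tolerates:

```python
def dense_strides(shape, element_size):
    # row-major (C-order) strides in bytes for densely packed data
    strides = []
    acc = element_size
    for extent in reversed(shape):
        strides.append(acc)
        acc *= extent
    return tuple(reversed(strides))

# torch.arange(1, 5) defaults to int64 (8 bytes per element)
print(dense_strides((1, 2, 2), 8))  # (32, 16, 8)
```

An exporter that reports 8 instead of 32 for the extent-1 dimension still describes the same 4-element buffer, since that stride is never multiplied by a nonzero index.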


@attr("pytorch")
@params(("cpu",), ("gpu",))
def test_tensor_to_torch(device: str):