Relax stride check in unit dimensions (#6285)
Conversation
Signed-off-by: Michał Zientkiewicz <mzient@gmail.com>
CI MESSAGE: [48103266]: BUILD STARTED |
Greptile Summary: This PR fixes a DLPack interoperability bug where older PyTorch versions report a stride of 1 for unit dimensions, triggering a spurious contiguity error on contiguous data.

Confidence Score: 5/5 — Safe to merge. The fix is minimal, correct, and well-targeted at the reported interoperability issue. No P0 or P1 issues found. The new logic is simpler and strictly more correct than what it replaces: skipping the stride check for unit extents is sound because a stride is only meaningful when the index can be non-zero. The stride_from_shape accumulation is unaffected (multiplying by 1 is a no-op). The zero-volume early return is a clean, orthogonal improvement. All prior review concerns have been addressed or resolved in the thread. No files require special attention.
| Filename | Overview |
|---|---|
| dali/python/backend_impl.cc | Simplifies CheckContiguousTensor by skipping the stride check for unit-extent dimensions and adding an early return for zero-volume tensors — both changes are logically correct and the stride_from_shape accumulation continues to work properly since multiplying by 1 is a no-op. |
| dali/test/python/experimental_mode/test_interop.py | Adds test_from_torch parameterized over CPU/GPU, exercising the leading-unit-dimension case (shape (1,2,2)) that previously failed with DLPack strides from older PyTorch builds; the version guard is justified as discussed in the review thread. |
Flowchart

```mermaid
%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[CheckContiguousTensor called] --> B{num_strides == num_extents?}
    B -->|No| C[DALI_ENFORCE fails]
    B -->|Yes| D[Scan dims for shape == 0]
    D -->|Any dim is 0| E[Return early - zero-volume tensor]
    D -->|No zero dims| F[stride_from_shape = element_size]
    F --> G{i >= 0?}
    G -->|No| H[All strides valid]
    G -->|Yes| I{shape at i == 1?}
    I -->|Yes - unit extent| J[Skip stride check]
    I -->|No| K{strides at i == stride_from_shape?}
    K -->|No| L[DALI_ENFORCE fails]
    K -->|Yes| M[stride_from_shape times shape at i, decrement i]
    J --> M
    M --> G
```
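The flow above can be sketched in plain Python. This is a simplified, illustrative port of the relaxed check — the real `CheckContiguousTensor` is C++ in dali/python/backend_impl.cc, and the function name and signature here are assumptions for demonstration only:

```python
def check_contiguous(shape, strides, element_size):
    """Return True if (shape, strides) describes contiguous data.

    Sketch of the relaxed logic: strides of unit-extent dimensions
    are ignored, and zero-volume tensors are accepted outright.
    """
    assert len(strides) == len(shape), "stride count must match extent count"
    # Zero-volume tensor: no element is ever addressed, so any
    # stride layout is acceptable -- return early.
    if any(extent == 0 for extent in shape):
        return True
    stride_from_shape = element_size
    # Walk dimensions innermost-first, accumulating the expected stride.
    for i in range(len(shape) - 1, -1, -1):
        # A unit extent's index is always 0, so its stride never
        # contributes to an address: skip the comparison for it.
        if shape[i] != 1 and strides[i] != stride_from_shape:
            return False
        stride_from_shape *= shape[i]  # multiplying by 1 is a no-op
    return True
```

With the PR's example (element strides, element_size = 1), `check_contiguous((1, 3, 15, 20), (1, 300, 20, 1), 1)` now passes: the only mismatch is in dim 0, whose extent is 1.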
Reviews (2): Last reviewed commit: "Remove dead variable."
dali/python/backend_impl.cc (outdated)

```cpp
int64_t stride_from_shape = element_size;
int64_t stride_from_shape_collapsed = 1;
int64_t last_non_one_dim = 1;
```
Dead variable left over from refactor
last_non_one_dim is initialised but never read after the simplification — it was part of the old stride_from_shape_collapsed logic that was removed. Most compilers will warn about it.
```diff
-int64_t last_non_one_dim = 1;
 int64_t stride_from_shape = element_size;
```
```python
if version.parse(torch.__version__) < version.parse("2.6.0"):
    raise SkipTest("Requires PyTorch >= 2.6.0")
```
Test guards against the version that triggers the bug
The PR description states the problem occurs in older PyTorch versions (strides of 1 for unit dimensions in DLPack). This test skips on < 2.6.0, meaning it never runs on the versions that actually exposed the bug. Consider adding a separate test path (or removing the version guard) to confirm the fix works on the affected older versions, or at minimum document why this gate is needed.
@greptileai: The test fails with PyTorch 2.7.1, so the check against 2.6.0 doesn't defeat it.
You're right: if the stride issue still occurs in PyTorch 2.7.1, then the >= 2.6.0 guard makes sense. The test will run on 2.7.1+ and verify the fix handles those cases. I misread the situation as the bug only affecting older versions.
Signed-off-by: Michał Zientkiewicz <mzient@gmail.com>
CI MESSAGE: [48104327]: BUILD STARTED |
Signed-off-by: Michał Zientkiewicz <mzient@gmail.com>
CI MESSAGE: [48104565]: BUILD STARTED |
CI MESSAGE: [48104565]: BUILD PASSED |
Category:
Bug fix (non-breaking change which fixes an issue)
Description:
This PR relaxes the stride check in unit dimensions. Some older versions of torch always give stride 1 for unit dimensions in DLPack capsules, causing an unnecessary error on contiguous data.
Example:

```
shape:          (1, 3, 15, 20)
dlpack strides: (1, 300, 20, 1)
```
Additional information:
Affected modules and functionalities:
Key points relevant for the review:
Tests:
Checklist
Documentation
DALI team only
Requirements
REQ IDs: N/A
JIRA TASK: N/A