Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 19 additions & 17 deletions exir/tensor.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
Expand Down Expand Up @@ -69,30 +69,32 @@
"""
from torch.fx.experimental.symbolic_shapes import (
guard_or_false,
guard_size_oblivious,
guard_or_true,
)

for _, s in enumerate(stride):
if guard_or_false(s == 0):
raise ValueError("0 in strides is not supported for ExecuTorch.")
for s in stride:
torch._check(s != 0, lambda: "0 in strides is not supported for ExecuTorch.")

class K(NamedTuple):
stride: int

def __lt__(self, other):
return guard_size_oblivious(self.stride < other.stride)

def __gt__(self, other):
return guard_size_oblivious(self.stride > other.stride)

def __le__(self, other):
return guard_size_oblivious(self.stride <= other.stride)

def __ge__(self, other):
return guard_size_oblivious(self.stride >= other.stride)

def __eq__(self, other):
return guard_size_oblivious(self.stride == other.stride)
# For backed/concrete strides this is practically a `<` operation.
# For unbacked, we return True if `<` is statically known, then
# try to answer symbolically with stride-ordering semantics:
# u0 < u0 -> False
# u0 < u1 (no info) -> False
# u0 < 2 * u0 -> True (divisibility)
# 1 < u0 -> True (1 divides anything; unprovable equality treated optimistically)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

stride of 1 will be in every tensor do we expect this to matter?

return (
guard_or_false(
self.stride < other.stride
) # statically known inequality
or (
guard_or_false(other.stride % self.stride == 0)
and guard_or_true(self.stride != other.stride)
) # symbolic inequality (e.g. u0 < 2048 * u0)
)

sorted_dims = [
i[0] for i in sorted(enumerate(stride), key=lambda x: K(x[1]), reverse=True)
Expand Down
45 changes: 44 additions & 1 deletion exir/tests/test_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,9 +246,52 @@ def test_dim_order_from_stride(self) -> None:
# dim[2] is broadcasting dim
# shape = (5, 1, 15, 10)
strides = (10, 10, 0, 1)
with self.assertRaises(ValueError):
# torch._check raises RuntimeError on concrete 0.
with self.assertRaises(RuntimeError):
dim_order = dim_order_from_stride(strides)

def test_dim_order_from_stride_unbacked(self) -> None:
"""
dim_order_from_stride should produce a sane permutation even when the
strides contain unbacked SymInts. The comparator falls back to
divisibility-based reasoning so common cases like (1, u0) and
(u0, 2 * u0) order correctly.
"""
from torch.fx.experimental.symbolic_shapes import ShapeEnv

shape_env = ShapeEnv()
u0 = shape_env.create_unbacked_symint()
u1 = shape_env.create_unbacked_symint()

# 1 < u0 should be True via divisibility (u0 % 1 == 0) + optimistic
# `1 != u0`. Descending sort puts u0 outer, stride 1 inner.
dim_order = dim_order_from_stride((1, u0))
self.assertEqual((1, 0), dim_order)

# u0 < 2 * u0 should be True via divisibility ((2*u0) % u0 == 0) and
# provable inequality (u0 != 0 after torch._check).
dim_order = dim_order_from_stride((u0, 2 * u0))
self.assertEqual((1, 0), dim_order)

# Mixed concrete + symbolic: (1, u0, 2 * u0). Descending stride order
# is (2*u0, u0, 1) -> indices (2, 1, 0).
dim_order = dim_order_from_stride((1, u0, 2 * u0))
self.assertEqual((2, 1, 0), dim_order)

# u0 < u1 (independent unbackeds) is genuinely ambiguous; stable sort
# preserves original order under reverse=True (no swap on ambiguous).
dim_order = dim_order_from_stride((u0, u1))
self.assertEqual((0, 1), dim_order)

# u0 < u0 is False both ways (symmetric); stable sort preserves order.
dim_order = dim_order_from_stride((u0, u0))
self.assertEqual((0, 1), dim_order)

# Unbacked stride of 0 (concrete 0 mixed with unbacked) -> RuntimeError
# via torch._check.
with self.assertRaises(RuntimeError):
dim_order_from_stride((u0, 0, 1))

def test_strides_from_dim_order(self) -> None:
sizes = []
dim_order = []
Expand Down
Loading