Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 35 additions & 22 deletions backends/transforms/remove_permutes_around_elementwise_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,14 +40,15 @@ class Subgraph:
default_factory=set
)

# Ops explicitly listed as permutable. This includes non-pointwise ops
# that need special dimension-argument handling (cat, mean, sum, slice)
# and quantize/dequantize ops not tagged as pointwise in ATen.
# In addition to this set, any op tagged with torch.Tag.pointwise is
# automatically considered permutable (see is_node_permutable).
permutable_ops: set[EdgeOpOverload] = {
exir_ops.edge.aten.add.Tensor,
exir_ops.edge.aten.mul.Tensor,
exir_ops.edge.aten.hardtanh.default,
exir_ops.edge.aten.clamp.default,
exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
# Ops that require special handling.
# Ops that require special handling of dimension arguments.
exir_ops.edge.aten.cat.default,
exir_ops.edge.aten.mean.dim,
exir_ops.edge.aten.sum.dim_IntList,
Expand All @@ -67,7 +68,9 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
end_permute = [start_permute.index(i) for i in range(len(start_permute))]

for user in node.users:
if user.target not in self.permutable_ops:
if user.target not in self.permutable_ops and not self._is_pointwise(
user.target
):
continue
# Create a separate subgraph for each user since there may be cases
# where only a portion of the users are permutable.
Expand Down Expand Up @@ -159,24 +162,34 @@ def _get_node_rank(self, node: torch.fx.Node) -> int | None:
return len(val.shape)
return None

@staticmethod
def _is_pointwise(target) -> bool:
"""Check if a target op is tagged as pointwise in ATen."""
op = getattr(target, "_op", None)
if op is not None and hasattr(op, "tags"):
return torch.Tag.pointwise in op.tags
return False

def is_node_permutable(self, node: torch.fx.Node) -> bool:
if node.target not in self.permutable_ops:
return False
if node.target in (
exir_ops.edge.aten.mean.dim,
exir_ops.edge.aten.sum.dim_IntList,
):
# keepdim should be True.
if len(node.args) >= 3:
if not node.args[2]:
return False
elif "keepdim" in node.kwargs:
if not node.kwargs["keepdim"]:
if node.target in self.permutable_ops:
# Special-case validation for dim-based ops.
if node.target in (
exir_ops.edge.aten.mean.dim,
exir_ops.edge.aten.sum.dim_IntList,
):
# keepdim should be True.
if len(node.args) >= 3:
if not node.args[2]:
return False
elif "keepdim" in node.kwargs:
if not node.kwargs["keepdim"]:
return False
else:
# Default keepdim is False.
return False
else:
# Default keepdim is False.
return False
return True
return True
# Accept any op tagged as pointwise in ATen (elementwise).
return self._is_pointwise(node.target)

def permute_subgraph(self, subgraph: Subgraph) -> None:
# Skip incoming permutes.
Expand Down
Loading