diff --git a/backends/transforms/remove_permutes_around_elementwise_ops.py b/backends/transforms/remove_permutes_around_elementwise_ops.py index 28b739c8d91..8a4170e207e 100644 --- a/backends/transforms/remove_permutes_around_elementwise_ops.py +++ b/backends/transforms/remove_permutes_around_elementwise_ops.py @@ -40,14 +40,15 @@ class Subgraph: default_factory=set ) + # Ops explicitly listed as permutable. This includes non-pointwise ops + # that need special dimension-argument handling (cat, mean, sum, slice) + # and quantize/dequantize ops not tagged as pointwise in ATen. + # In addition to this set, any op tagged with torch.Tag.pointwise is + # automatically considered permutable (see is_node_permutable). permutable_ops: set[EdgeOpOverload] = { - exir_ops.edge.aten.add.Tensor, - exir_ops.edge.aten.mul.Tensor, - exir_ops.edge.aten.hardtanh.default, - exir_ops.edge.aten.clamp.default, exir_ops.edge.quantized_decomposed.quantize_per_tensor.default, exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default, - # Ops that require special handling. + # Ops that require special handling of dimension arguments. exir_ops.edge.aten.cat.default, exir_ops.edge.aten.mean.dim, exir_ops.edge.aten.sum.dim_IntList, @@ -67,7 +68,9 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult: end_permute = [start_permute.index(i) for i in range(len(start_permute))] for user in node.users: - if user.target not in self.permutable_ops: + if user.target not in self.permutable_ops and not self._is_pointwise( + user.target + ): continue # Create a separate subgraph for each user since there may be cases # where only a portion of the users are permutable. @@ -159,24 +162,34 @@ def _get_node_rank(self, node: torch.fx.Node) -> int | None: return len(val.shape) return None + @staticmethod + def _is_pointwise(target) -> bool: + """Check if a target op is tagged as pointwise in ATen.""" + op = getattr(target, "_op", None) + if op is not None and hasattr(op, "tags"): + return torch.Tag.pointwise in op.tags + return False + def is_node_permutable(self, node: torch.fx.Node) -> bool: - if node.target not in self.permutable_ops: - return False - if node.target in ( - exir_ops.edge.aten.mean.dim, - exir_ops.edge.aten.sum.dim_IntList, - ): - # keepdim should be True. - if len(node.args) >= 3: - if not node.args[2]: - return False - elif "keepdim" in node.kwargs: - if not node.kwargs["keepdim"]: + if node.target in self.permutable_ops: + # Special-case validation for dim-based ops. + if node.target in ( + exir_ops.edge.aten.mean.dim, + exir_ops.edge.aten.sum.dim_IntList, + ): + # keepdim should be True. + if len(node.args) >= 3: + if not node.args[2]: + return False + elif "keepdim" in node.kwargs: + if not node.kwargs["keepdim"]: + return False + else: + # Default keepdim is False. return False - else: - # Default keepdim is False. - return False - return True + return True + # Accept any op tagged as pointwise in ATen (elementwise). + return self._is_pointwise(node.target) def permute_subgraph(self, subgraph: Subgraph) -> None: # Skip incoming permutes.