diff --git a/backends/apple/coreml/partition/coreml_partitioner.py b/backends/apple/coreml/partition/coreml_partitioner.py index 57bd793a10a..b4e550d8479 100644 --- a/backends/apple/coreml/partition/coreml_partitioner.py +++ b/backends/apple/coreml/partition/coreml_partitioner.py @@ -63,6 +63,26 @@ ) +_ARG_MIN_MAX_TARGETS = ( + torch.ops.aten.argmax.default, + torch.ops.aten.argmin.default, + exir_ops.edge.aten.argmax.default, + exir_ops.edge.aten.argmin.default, +) + + +def _is_arg_min_max_over_flattened_input(node: torch.fx.Node) -> bool: + """``argmin``/``argmax`` with ``dim=None`` reduces over the flattened input. + + CoreML doesn't support that reduction shape and intermittently crashes + the process at runtime — see pytorch/executorch#11715. + """ + if node.target not in _ARG_MIN_MAX_TARGETS: + return False + dim = node.args[1] if len(node.args) >= 2 else node.kwargs.get("dim", None) + return dim is None + + def _is_view_op(op: torch._ops.OpOverload) -> bool: schema = op._schema if len(schema.arguments) == 0: @@ -132,6 +152,13 @@ def should_override_support(self, node) -> bool: ) return True + if _is_arg_min_max_over_flattened_input(node): + self.log_once( + "torch.ops.aten.{argmax, argmin}.default with dim=None is " + "not supported by CoreML. Overriding op support." + ) + return True + # TODO: enable this after bugs in ExecuTorch's partitioner are fixed # # If lower_full_graph=False, do not partition nodes with symbolic args because it can result in symbolic args # # in the placeholders due to partitioning, which CoreML does not support diff --git a/backends/apple/coreml/test/test_coreml_partitioner.py b/backends/apple/coreml/test/test_coreml_partitioner.py index a2321ee199f..0e75d6024e4 100644 --- a/backends/apple/coreml/test/test_coreml_partitioner.py +++ b/backends/apple/coreml/test/test_coreml_partitioner.py @@ -386,6 +386,53 @@ def forward(self, x): self.assertIn("executorch_call_delegate", op_names) self.assertNotIn("aten.randn.default", op_names) + def test_argmax_argmin_dim_none_is_skipped(self): + """ + Regression test for https://github.com/pytorch/executorch/issues/11715. + + argmax/argmin with dim=None reduces over the flattened tensor, which + CoreML does not support; the resulting model intermittently crashes + the process at runtime. The partitioner must reject these so they + fall back to the portable backend, while still delegating the + ordinary dim=int form. + """ + + class FlatModel(torch.nn.Module): + def forward(self, x): + return torch.argmax(x, dim=None, keepdim=False) + torch.argmin( + x, dim=None + ) + + ep = torch.export.export( + FlatModel().eval(), (torch.randn(10, 10),), strict=True + ) + edge = executorch.exir.to_edge_transform_and_lower( + ep, partitioner=[CoreMLPartitioner()] + ) + op_names = [ + n.target.__name__ + for n in edge.exported_program().graph.nodes + if n.op == "call_function" + ] + self.assertIn("aten.argmax.default", op_names) + self.assertIn("aten.argmin.default", op_names) + + class DimModel(torch.nn.Module): + def forward(self, x): + return torch.argmax(x, dim=1) + + ep = torch.export.export(DimModel().eval(), (torch.randn(10, 10),), strict=True) + edge = executorch.exir.to_edge_transform_and_lower( + ep, partitioner=[CoreMLPartitioner()] + ) + op_names = [ + n.target.__name__ + for n in edge.exported_program().graph.nodes + if n.op == "call_function" + ] + self.assertIn("executorch_call_delegate", op_names) + self.assertNotIn("aten.argmax.default", op_names) + def test_deprecation_warning_for_to_backend_workflow(self): """ Test that the deprecated to_edge + to_backend workflow shows a deprecation warning.