
Skip argmin/argmax with dim=None in CoreML partitioner #19247

Open

john-rocky wants to merge 1 commit into pytorch:main from john-rocky:coreml/skip-argmax-dim-none

Conversation

john-rocky (Contributor) commented May 1, 2026

Summary

argmax(x, dim=None) / argmin(x, dim=None) reduces over the flattened
tensor. CoreML does not support this reduction, and the resulting model
intermittently crashes the process at runtime (the issue reproducer
crashes 100% of the time on M1 Pro when the cell is run twice).

Detect the dim is None case in should_override_support so the op
falls back to the portable backend. The ordinary dim=int form is
unaffected and still gets delegated.

Fixes #11715.
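
For reference, a minimal sketch of the dim=None check (hedged: the helper name follows the discussion later in this thread, and the exact wiring into should_override_support is assumed):

import torch
from torch.fx import Node

# Hedged sketch: the real check lives in the CoreML partitioner's
# should_override_support; names here mirror the ones discussed in this PR.
_ARG_MIN_MAX_OPS = (
    torch.ops.aten.argmax.default,
    torch.ops.aten.argmin.default,
)

def _is_arg_min_max_over_flattened_input(node: Node) -> bool:
    # argmax/argmin reduce over the flattened input only when dim is None
    # (dim is the second positional argument and defaults to None).
    if node.target not in _ARG_MIN_MAX_OPS:
        return False
    dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim")
    return dim is None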

Test plan

Added test_argmax_argmin_dim_none_is_skipped covering both branches:

  • argmax(x, dim=None) + argmin(x, dim=None) — neither op is delegated.
  • argmax(x, dim=1) — gets delegated as before.
$ python -m unittest -v executorch.backends.apple.coreml.test.test_coreml_partitioner.TestCoreMLPartitioner.test_argmax_argmin_dim_none_is_skipped
Ran 1 test in 1.042s

OK
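
A condensed sketch of what the added test checks (hedged: the import paths and the to_edge/to_backend flow are assumed; the real test lives in executorch/backends/apple/coreml/test/test_coreml_partitioner.py):

import torch
# Assumed import paths; adjust to the actual ExecuTorch layout.
from executorch.backends.apple.coreml.partition import CoreMLPartitioner
from executorch.exir import to_edge

class FlatModel(torch.nn.Module):
    def forward(self, x):
        return torch.argmax(x, dim=None, keepdim=False) + torch.argmin(x, dim=None)

def top_level_targets(graph_module):
    # Ops that remain in the top-level (un-delegated) graph after partitioning.
    return {
        str(node.target)
        for node in graph_module.graph.nodes
        if node.op == "call_function"
    }

ep = torch.export.export(FlatModel(), (torch.randn(4, 5),))
edge = to_edge(ep).to_backend(CoreMLPartitioner())
targets = top_level_targets(edge.exported_program().graph_module)

# dim=None: neither op should have been delegated to CoreML.
assert any("argmax" in t for t in targets)
assert any("argmin" in t for t in targets)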

Authored with Claude.

cc @metascroy

john-rocky requested a review from shoumikhin as a code owner on May 1, 2026 04:48
pytorch-bot (Bot) commented May 1, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/19247

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEV

There is 1 currently active SEV. If your PR is affected, please review it.

This comment was automatically generated by Dr. CI and updates every 15 minutes.

meta-cla (Bot) commented May 1, 2026

Hi @john-rocky!

Thank you for your pull request and welcome to our community.

Action Required

In order to merge any pull request (code, docs, etc.), we require contributors to sign our Contributor License Agreement, and we don't seem to have one on file for you.

Process

In order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (e.g., your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA.

Once the CLA is signed, our tooling will perform checks and validations. Afterwards, the pull request will be tagged with CLA signed. The tagging process may take up to 1 hour after signing. Please give it that time before contacting us about it.

If you have received this in error or have any questions, please contact us at cla@meta.com. Thanks!

meta-cla (Bot) added the CLA Signed label (this label is managed by the Facebook bot; authors need to sign the CLA before a PR can be reviewed) on May 1, 2026
john-rocky (Contributor, Author) commented

@pytorchbot label "release notes: apple"

pytorch-bot (Bot) added the release notes: apple label (changes to the Apple backend delegate) on May 2, 2026
# https://github.com/pytorch/executorch/issues/11715
# argmin/argmax with dim=None reduces over the flattened input, which
# CoreML does not support and causes intermittent process crashes.
if node.target in [
Inline review comment from a Contributor on the diff hunk above:

Can we factor this out into a utility that gets called here, e.g., is_arg_min_max_over_flattened_input?

metascroy (Contributor) commented

@john-rocky do you know what causes the crash in CoreML?

nil-is-all added the module: mlx label (issues related to MLX Backend: Metal-accelerated inference on Apple Silicon) on May 4, 2026
john-rocky (Contributor, Author) commented

Honest answer: not at the CoreML internals level. The original repro in #11715 (filed by @GregoryComer) is a 100%-reproducible segfault on M1 Pro when argmax(dim=None) is delegated; I confirmed it still compiles cleanly today (MIL runs through with no errors), but it's a runtime fault and we don't have a stack trace from CoreML's side. The pattern — full reduction over a flattened input — is unusual enough that it likely hits a less-tested code path in CoreML's reduce implementation, but that's speculation.

The conservative read is that CoreML's full-reduction-as-argmax-with-dim=None path is buggy in some configuration on iOS / macOS, so refusing to delegate it gives users a stable fallback to portable while the runtime is opaque. The dim=int form is unaffected — the test exercises both branches and asserts only the dim=None form falls back.

I'm happy to file an internal CoreML bug if you want to chase the runtime cause, but the partitioner-side mitigation seemed worth doing independently.

john-rocky added a commit to john-rocky/executorch that referenced this pull request May 4, 2026
metascroy (Contributor) commented

> Honest answer: not at the CoreML internals level. […]

Sounds good. I was just curious if you reproed the issue with coremltools directly, or ran lldb on the process to find the thing causing the runtime crash.

metascroy (Contributor) commented

This PR looks good @john-rocky. We can hopefully merge sometime next week when our CI is clean again.

(And to confirm: did you verify that both argmax and argmin have the issue with dim=None today?)

john-rocky (Contributor, Author) commented

Yes — both argmax and argmin look the same on the partitioner side: they share the same FX op overloads (aten.argmax.default / aten.argmin.default) with an identical (input, dim=None, keepdim=False) signature, so the dim=None reduction path they hit is structurally the same code in coremltools. The original repro in #11715 only used argmax, but the test in this PR (test_argmax_argmin_dim_none_is_skipped) exercises both:

class FlatModel(torch.nn.Module):
    def forward(self, x):
        return torch.argmax(x, dim=None, keepdim=False) + torch.argmin(x, dim=None)

and asserts that both aten.argmax.default and aten.argmin.default end up in the top-level (un-delegated) graph. In addition, _is_arg_min_max_over_flattened_input matches the same set of ops, so neither side can drift out of the deny-list without the other.

I haven't independently runtime-reproduced the segfault for argmin on device (don't have an M1 Pro on hand), only the partitioner-side behavior. If you'd prefer this PR keep just argmax (the one with the verified runtime crash) and add argmin later under its own confirmed repro, I'm happy to narrow it.

metascroy (Contributor) commented

> Yes — both argmax and argmin look the same on the partitioner side […]

I think there's now a merge conflict from your other PR :)

> I haven't independently runtime-reproduced the segfault for argmin on device (don't have an M1 Pro on hand) […]

Did you reproduce either of these crashes on any device? The issue is quite old, and I don't know whether CoreML has since fixed it in a newer version of macOS/iOS.

Commit message (force-pushed commit):

Closes pytorch#11715. argmax / argmin with dim=None reduce over the flattened input; coremltools 9.0 lowers this branch incorrectly on current macOS — the resulting CoreML model returns a per-row argmax of shape (N,) instead of the PyTorch-style scalar. The partitioner now rejects these so they fall back to the portable backend, while ordinary dim=int argmax/argmin still delegates.
john-rocky force-pushed the coreml/skip-argmax-dim-none branch from a01c061 to b520067 on May 15, 2026 03:43
john-rocky (Contributor, Author) commented

Rebased — branch is now a single squashed commit (b520067) on top of current main (09a7cbe3a6), incorporating the conflict resolution with #19246's deny-list. Diff is unchanged in shape (+75 / −0 across the partitioner + its test).

On the repro question: tested on macOS 26 / coremltools 9.0 by bypassing the partitioner skip and running argmax(dim=None) / argmin(dim=None) end-to-end through ct.convert + MLModel.predict. The result was worse than I expected:

PyTorch torch.argmax(x, dim=None) -> scalar 48
CoreML  ct.convert + predict      -> array([4, 4, 3, 6, 8, 0, 2, 1, 8, 2]) shape (10,)

The original process-crash from #11715 is gone, but the lowered model now silently returns a per-row argmax (shape [N]) instead of the flat scalar PyTorch produces — same shape mismatch for argmin. So coremltools 9.0 hasn't fixed the underlying op support; it's traded a loud crash for a silent miscompile, which is a stronger reason to keep the partitioner skip than the original crash was. Any model that hits this branch through the delegate today gets wrong outputs without any error surface.
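
For reference, a minimal sketch of that direct repro (hedged: coremltools public API only; the input name, shape, and deployment target are illustrative assumptions):

import numpy as np
import torch
import coremltools as ct

class ArgmaxFlat(torch.nn.Module):
    def forward(self, x):
        # dim=None: reduce over the flattened input; PyTorch returns a scalar index.
        return torch.argmax(x, dim=None)

x = torch.randn(10, 10)
traced = torch.jit.trace(ArgmaxFlat(), x)

mlmodel = ct.convert(
    traced,
    inputs=[ct.TensorType(name="x", shape=x.shape)],
    minimum_deployment_target=ct.target.iOS16,  # assumption; any recent target
)

torch_out = torch.argmax(x, dim=None).item()
coreml_out = np.asarray(list(mlmodel.predict({"x": x.numpy()}).values())[0])

print("PyTorch:", torch_out)                     # scalar flat index
print("CoreML :", coreml_out, coreml_out.shape)  # observed: per-row argmax of shape (10,)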

Happy to add a note in the partitioner docstring linking to this thread if that helps future readers.

