feat(MLX): Add handlers for aten.flip and aten.trunc#18982

Open
Jah-yee wants to merge 1 commit into pytorch:main from Jah-yee:mlx-flip-handler

Conversation


@Jah-yee Jah-yee commented Apr 18, 2026

Good day

Summary

This PR adds MLX op handlers for two PyTorch aten operators:

  • aten.flip: Reverse tensor elements along specified dimensions
  • aten.trunc: Truncate floating-point values toward zero

Implementation Details

aten.flip (issue #18918)

Uses SliceNode, chained once per flip dimension:

For each dimension in the dims list, we emit a SliceNode from start=dim_size - 1 to stop=-(dim_size + 1) with step=-1. Chaining these slices achieves the flip operation without needing a dedicated MLX kernel.
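The chaining can be sketched with NumPy basic slicing (illustrative only — the actual handler emits MLX SliceNodes with these parameters rather than Python slices):

```python
import numpy as np

def flip_via_slices(x, dims):
    """Reverse x along each listed dim by chaining full-axis slices with step=-1."""
    out = x
    for dim in dims:
        n = out.shape[dim]
        index = [slice(None)] * out.ndim
        # start=n-1, stop=-(n+1), step=-1 walks the whole axis in reverse;
        # stop=-(n+1) resolves to before index 0, so index 0 is included.
        index[dim] = slice(n - 1, -(n + 1), -1)
        out = out[tuple(index)]
    return out

x = np.arange(6).reshape(2, 3)
assert np.array_equal(flip_via_slices(x, [0, 1]), np.flip(x, axis=(0, 1)))
```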

aten.trunc (issue #18923)

Uses the mathematical decomposition trunc(x) = where(x >= 0, floor(x), ceil(x)):

Since MLX doesn't have a native trunc operation, we decompose it using existing nodes (GreaterEqualNode, FloorNode, CeilNode, WhereNode). This is the same approach used in other PyTorch decomposition strategies.
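A NumPy sketch of the same decomposition (the handler emits the corresponding GreaterEqual/Floor/Ceil/Where nodes):

```python
import numpy as np

def trunc_via_where(x):
    # Round toward zero: floor for non-negatives, ceil for negatives.
    return np.where(x >= 0, np.floor(x), np.ceil(x))

x = np.array([2.7, -2.7, 0.5, -0.5, 0.0])
assert np.array_equal(trunc_via_where(x), np.trunc(x))
```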

Testing

Both handlers use existing MLX nodes (SliceNode, IdCopyNode, GreaterEqualNode, FloorNode, CeilNode, WhereNode) — no new schema or runtime code needed.

References

  • aten.flip: pytorch/executorch issue #18918
  • aten.trunc: pytorch/executorch issue #18923

Thank you for your attention. If there are any issues or suggestions, please leave a comment and I will address them promptly.

Warmly,
RoomWithOutRoof

- flip: decompose using SliceNode with step=-1 chained per dim
- trunc: use where(x>=0, floor(x), ceil(x)) since MLX lacks native trunc

Both are good-first-issues from pytorch/executorch.
Signed-off-by: RoomWithOutRoof <RoomWithOutRoof@users.noreply.github.com>

pytorch-bot Bot commented Apr 18, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/18982

Note: Links to docs will display an error until the docs builds have been completed.

❌ 14 New Failures, 1 Cancelled Job, 2 Unrelated Failures

As of commit 1c4b0a1 with merge base 490ec5c:

NEW FAILURES - The following jobs have failed:

CANCELLED JOB - The following job was cancelled. Please retry:

BROKEN TRUNK - The following jobs failed but were present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla meta-cla Bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Apr 18, 2026
@github-actions

This PR needs a release notes: label

If your change should be included in the release notes (i.e. would users of this library care about this change?), please use a label starting with release notes:. This helps us keep track and include your important work in the next release notes.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "release notes: none"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

@nil-is-all nil-is-all added the module: mlx Issues related to MLX Backend: Metal-accelerated inference on Apple Silicon label Apr 20, 2026
@nil-is-all
Contributor

@claude review the PR


claude Bot commented Apr 20, 2026

Claude finished @nil-is-all's task in 3m 52s


Review of PR #18982: Add handlers for aten.flip and aten.trunc

  • Read PR diff and changed files
  • Review existing MLX ops patterns for consistency
  • Analyze aten.flip handler implementation
  • Analyze aten.trunc handler implementation
  • Post review feedback

Overall

Both handlers follow the established patterns in backends/mlx/ops.py well — registry decoration, arg/kwarg validation, metadata access, temp slot usage, and node emission are all consistent with existing handlers. The mathematical decompositions are correct. I have a few observations, mostly minor.


aten.flip handler (backends/mlx/ops.py:447-484)

Correctness: The approach is sound. Chaining SliceNode with step=-1 for each dimension correctly implements flip. The start=dim_size - 1, stop=-(dim_size + 1) parameters reverse the full axis.

Issues / suggestions:

  1. Dynamic shape dimensions are silently baked in (minor, non-blocking): If x_meta.shape[dim_val] is a SymInt (dynamic shape), the int() conversion inside P.to_int_or_vid() will bake in the trace-time value, making the handler incorrect for inputs whose shape varies at runtime. Other handlers like _index_tensor_handler (ops.py:1841) explicitly check for this and raise a clear error. Consider adding a guard:

    dim_size = x_meta.shape[dim_val]
    if not isinstance(dim_size, int):
        raise ValueError(
            f"aten.flip: dimension {dim_val} has dynamic size {dim_size}, "
            "which is not supported"
        )


  2. Negative dims (minor, non-blocking): aten.flip in PyTorch accepts negative dimension values (e.g., dims=[-1]). The code does x_meta.shape[dim_val] which works with negative indices on Python tuples, and P.to_int_or_vid(dim) will produce a negative IntOrVid. Whether the MLX runtime's SliceNode handles negative axis values depends on the runtime implementation — worth verifying this works in tests.

  3. Unnecessary IdCopyNode when dims is empty (nitpick): If dims is [], the handler emits an IdCopyNode that just copies input to output. This is functionally correct but produces a no-op node. Not worth changing unless performance of the graph construction matters.

  4. Unnecessary IdCopyNode in general (minor): The final IdCopyNode is always emitted to copy from the last tmp slot to the output slot. An alternative pattern would be to emit the final SliceNode directly into final_out (detecting the last iteration), avoiding an extra node. This is a minor optimization and the current approach is clearer, so this is non-blocking.
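On point 2, if the runtime's SliceNode turns out not to accept negative axes, the dims could be normalized at trace time before emitting the slices. A sketch (the helper name is hypothetical; `ndim` would come from `x_meta`):

```python
def normalize_dim(dim: int, ndim: int) -> int:
    """Map a possibly-negative dim (e.g. -1 for the last axis) into [0, ndim)."""
    if not -ndim <= dim < ndim:
        raise IndexError(f"dim {dim} out of range for rank {ndim}")
    return dim % ndim

assert normalize_dim(-1, 4) == 3  # last axis of a rank-4 tensor
assert normalize_dim(2, 4) == 2   # non-negative dims pass through
```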


aten.trunc handler (backends/mlx/ops.py:487-535)

Correctness: The decomposition trunc(x) = where(x >= 0, floor(x), ceil(x)) is mathematically correct and a well-known approach.

Issues / suggestions:

  1. NaN propagation (worth noting): For NaN inputs, floor(NaN) and ceil(NaN) both produce NaN, and NaN >= 0 is false, so where selects ceil(NaN) = NaN. The result is correct — trunc(NaN) should return NaN. Good.

  2. Integer dtype handling (minor, non-blocking): If the input tensor has an integer dtype, trunc is a no-op (integers are already truncated). The current implementation will still emit GreaterEqual, Floor, Ceil, and Where nodes, which is wasteful but functionally correct since floor(int) == ceil(int) == int. PyTorch's aten.trunc typically doesn't get called on integer tensors through torch.export, so this is unlikely to matter in practice.

  3. emit_lifted_constant dtype (minor, non-blocking): At line 505, emit_lifted_constant(P, 0.0, dtype) creates a zero scalar with the input's dtype. If dtype happened to be an integer type, this would create 0 as that integer type, and GreaterEqualNode would compare integer vs integer — which is fine. No issue here, just noting the consistency.
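The NaN reasoning in point 1 is easy to confirm numerically (NumPy as a stand-in for the MLX nodes):

```python
import numpy as np

x = np.array([np.nan, 1.5, -1.5])
# nan >= 0 evaluates to False, so where() selects ceil(nan), which is nan.
out = np.where(x >= 0, np.floor(x), np.ceil(x))
assert np.isnan(out[0]) and out[1] == 1.0 and out[2] == -1.0
```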


Missing tests

The PR description mentions testing but the test file (backends/mlx/test/test_ops.py) doesn't appear to have corresponding test cases for flip or trunc. Adding FlipTest and TruncTest classes following the @register_test / OpTestCase pattern used by other tests would strengthen this PR. This is the most important gap.
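For reference, a minimal correctness check against NumPy could look like the following (a sketch only; the real suite would follow the repo's @register_test / OpTestCase pattern and run through the MLX backend rather than NumPy):

```python
import numpy as np

def decomposed_flip(x, dims):
    """Mirror the handler: one reversing slice per dim."""
    out = x
    for d in dims:
        idx = [slice(None)] * out.ndim
        idx[d] = slice(None, None, -1)
        out = out[tuple(idx)]
    return out

def decomposed_trunc(x):
    """Mirror the handler: where(x >= 0, floor(x), ceil(x))."""
    return np.where(x >= 0, np.floor(x), np.ceil(x))

x = np.random.default_rng(0).standard_normal((2, 3, 4))
for dims in ([0], [1, 2], [-1]):  # include a negative dim, per the review
    assert np.array_equal(decomposed_flip(x, dims), np.flip(x, axis=tuple(dims)))
assert np.array_equal(decomposed_trunc(10 * x), np.trunc(10 * x))
```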


Summary

The implementation is correct and follows existing patterns well. The main feedback:

  #      Severity  Issue
  1      Medium    Add dynamic shape guard for flip dimension sizes
  2      Low       Verify negative dims work with MLX SliceNode axis
  Tests  Medium    Add test cases for both ops in test_ops.py

The rest are nitpicks. Nice work on the decomposition approaches — using SliceNode chaining for flip and the where(x>=0, floor, ceil) pattern for trunc are clean solutions that avoid needing new MLX runtime ops.

