From 4edd4d953d99688e174f09f9d90e6d9be709acad Mon Sep 17 00:00:00 2001 From: RoomWithOutRoof Date: Sat, 18 Apr 2026 19:12:37 +0800 Subject: [PATCH] Add MLX op handler for aten.trunc Add a decomposed handler for aten.trunc that uses existing MLX nodes (floor, ceil, where, greater_equal) to implement truncation toward zero on Metal GPU. Fixes: pytorch/executorch#18923 --- backends/mlx/ops.py | 58 +++++++++++++++++++++++++++++++++++ backends/mlx/test/test_ops.py | 1 + 2 files changed, 59 insertions(+) diff --git a/backends/mlx/ops.py b/backends/mlx/ops.py index 4dc891ee984..6b926b76efe 100644 --- a/backends/mlx/ops.py +++ b/backends/mlx/ops.py @@ -2851,6 +2851,64 @@ def _clamp_handler(P: MLXProgramBuilder, n: Node) -> Slot: return out +@REGISTRY.register(target=[torch.ops.aten.trunc.default]) +def _trunc_handler(P: MLXProgramBuilder, n: Node) -> Slot: + """Handle aten.trunc - truncate toward zero. + + trunc(x) = where(x >= 0, floor(x), ceil(x)) + """ + args = P.args(n) + require_args(args, 1, 1, "aten.trunc") + require_kwargs(P.kwargs(n), set(), "aten.trunc") + x = args[0] + + x_meta = n.args[0].meta.get("val") + dtype = x_meta.dtype if x_meta is not None else torch.float32 + + # Create zero constant for comparison + zero_slot = emit_lifted_constant(P, 0.0, dtype) + + # x >= 0 + _, ge_zero = P.make_tmp_slot() + P.emit( + GreaterEqualNode( + a=P.slot_to_tid(x), + b=P.slot_to_tid(zero_slot), + out=P.slot_to_tid(ge_zero), + ) + ) + + # floor(x) + _, floor_x = P.make_tmp_slot() + P.emit( + FloorNode( + x=P.slot_to_tid(x), + out=P.slot_to_tid(floor_x), + ) + ) + + # ceil(x) + _, ceil_x = P.make_tmp_slot() + P.emit( + CeilNode( + x=P.slot_to_tid(x), + out=P.slot_to_tid(ceil_x), + ) + ) + + # where(x >= 0, floor(x), ceil(x)) + out = P.make_or_get_slot(n) + P.emit( + WhereNode( + condition=P.slot_to_tid(ge_zero), + x=P.slot_to_tid(floor_x), + y=P.slot_to_tid(ceil_x), + out=P.slot_to_tid(out), + ) + ) + return out + + @REGISTRY.register( target=[torch.ops.aten.expand.default, torch.ops.aten.expand_copy.default] ) diff --git a/backends/mlx/test/test_ops.py b/backends/mlx/test/test_ops.py index e5ece4931b9..2b725a369cc 100644 --- a/backends/mlx/test/test_ops.py +++ b/backends/mlx/test/test_ops.py @@ -4090,6 +4090,7 @@ def create_model(self) -> nn.Module: {"op_name": "erf", "op_fn": torch.erf}, {"op_name": "expm1", "op_fn": torch.expm1}, {"op_name": "round", "op_fn": torch.round, "input_fn": _input_fn(scale=10)}, + {"op_name": "trunc", "op_fn": torch.trunc, "shapes": _SHAPES_3, "input_fn": _input_fn(scale=10)}, {"op_name": "reciprocal", "op_fn": torch.reciprocal, "input_fn": _input_fn(offset=1.0)}, {"op_name": "sqrt", "op_fn": torch.sqrt, "input_fn": _input_fn(uniform=True, offset=0.1)}, {"op_name": "abs", "op_fn": torch.abs},