# Handlers added to backends/mlx/ops.py by this hunk; in that file they sit
# between `_isnan_handler` and the `_BINARY_OPS` table.


@REGISTRY.register(target=[torch.ops.aten.flip.default])
def _flip_handler(P: MLXProgramBuilder, n: Node) -> Slot:
    """Handle aten.flip - reverse tensor along specified dimensions.

    Decomposed using SliceNode with step=-1, chained for each flip dim.
    For example, flip(x, dims=[0, 2]) chains:
        slice(x, axis=0, start=size-1, stop=-(size+1), step=-1)
        then slice(result, axis=2, start=size-1, stop=-(size+1), step=-1)

    Negative entries in ``dims`` are converted to their non-negative
    equivalent before being emitted as the slice axis, so the emitted axis
    always names the same dimension whose size was read from the metadata.

    Args:
        P: Program builder used to allocate slots and emit nodes.
        n: The aten.flip node. It must carry exactly two positional
            arguments (input tensor, dims) and no keyword arguments.

    Returns:
        The slot bound to ``n`` that receives the flipped tensor.

    Raises:
        ValueError: If the input node has no "val" metadata.
        IndexError: If an entry of ``dims`` is out of range for the input
            shape (raised by the shape lookup).
    """
    args = P.args(n)
    require_args(args, 2, 2, "aten.flip")
    require_kwargs(P.kwargs(n), set(), "aten.flip")
    x, dims = args

    # Dimension sizes are taken from the export-time metadata of the input.
    x_meta = n.args[0].meta.get("val")
    if x_meta is None:
        raise ValueError("Input tensor metadata not found for aten.flip")

    shape = x_meta.shape
    ndim = len(shape)

    out = x  # Start with input, chain slices
    for dim in dims:
        dim_val = int(dim)
        # Look the size up with the caller's index first so an out-of-range
        # dim still fails here exactly as it did before normalization.
        dim_size = shape[dim_val]
        # Emit a non-negative axis rather than relying on the runtime to
        # wrap negative axes the way Python indexing does.
        axis = dim_val + ndim if dim_val < 0 else dim_val

        _, tmp = P.make_tmp_slot()
        P.emit(
            SliceNode(
                x=P.slot_to_tid(out),
                out=P.slot_to_tid(tmp),
                axis=P.to_int_or_vid(axis),
                # With Python slicing rules, x[size-1 : -(size+1) : -1] is
                # the full reversal: the stop lies one position before
                # index 0, whereas a stop of -1 would mean "last element".
                # NOTE(review): assumes SliceNode follows the same rules.
                start=P.to_int_or_vid(dim_size - 1),
                stop=P.to_int_or_vid(-(dim_size + 1)),
                step=-1,
            )
        )
        out = tmp

    # Copy the end of the chain (or the untouched input when dims is empty)
    # into the slot that belongs to this node.
    final_out = P.make_or_get_slot(n)
    P.emit(IdCopyNode(x=P.slot_to_tid(out), out=P.slot_to_tid(final_out)))
    return final_out


@REGISTRY.register(target=[torch.ops.aten.trunc.default])
def _trunc_handler(P: MLXProgramBuilder, n: Node) -> Slot:
    """Handle aten.trunc - truncate toward zero.

    Uses the mathematical property:
        trunc(x) = where(x >= 0, floor(x), ceil(x))

    This is needed because MLX doesn't have a native trunc operation.

    Four nodes are emitted, in order: GreaterEqualNode (x >= 0), FloorNode,
    CeilNode, and a WhereNode that selects between the two rounded results.

    Args:
        P: Program builder used to allocate slots and emit nodes.
        n: The aten.trunc node. It must carry exactly one positional
            argument (the input tensor) and no keyword arguments.

    Returns:
        The slot bound to ``n`` that receives the truncated tensor.
    """
    args = P.args(n)
    require_args(args, 1, 1, "aten.trunc")
    require_kwargs(P.kwargs(n), set(), "aten.trunc")
    x = args[0]

    # The zero constant is created in the input's dtype; when the input has
    # no "val" metadata the handler falls back to float32 instead of failing.
    x_meta = n.args[0].meta.get("val")
    dtype = x_meta.dtype if x_meta is not None else torch.float32

    # Create zero constant for comparison
    zero_slot = emit_lifted_constant(P, 0.0, dtype)

    # x >= 0
    _, ge_zero = P.make_tmp_slot()
    P.emit(
        GreaterEqualNode(
            a=P.slot_to_tid(x),
            b=P.slot_to_tid(zero_slot),
            out=P.slot_to_tid(ge_zero),
        )
    )

    # floor(x) - the result for non-negative inputs
    _, floor_x = P.make_tmp_slot()
    P.emit(FloorNode(x=P.slot_to_tid(x), out=P.slot_to_tid(floor_x)))

    # ceil(x) - the result for negative inputs
    _, ceil_x = P.make_tmp_slot()
    P.emit(CeilNode(x=P.slot_to_tid(x), out=P.slot_to_tid(ceil_x)))

    # where(x >= 0, floor(x), ceil(x))
    out = P.make_or_get_slot(n)
    P.emit(
        WhereNode(
            condition=P.slot_to_tid(ge_zero),
            x=P.slot_to_tid(floor_x),
            y=P.slot_to_tid(ceil_x),
            out=P.slot_to_tid(out),
        )
    )
    return out