diff --git a/backends/mlx/ops.py b/backends/mlx/ops.py
index 3f7da88a793..23377686a65 100644
--- a/backends/mlx/ops.py
+++ b/backends/mlx/ops.py
@@ -444,6 +444,41 @@ def _isnan_handler(P: MLXProgramBuilder, n: Node) -> Slot:
     return out
 
 
+@REGISTRY.register(target=[torch.ops.aten.isinf.default])
+def _isinf_handler(P: MLXProgramBuilder, n: Node) -> Slot:
+    """Handle aten.isinf - check for infinite values element-wise.
+
+    isinf(x) is equivalent to abs(x) == inf.
+    """
+    args = P.args(n)
+    require_args(args, 1, 1, "aten.isinf")
+    require_kwargs(P.kwargs(n), set(), "aten.isinf")
+    x = args[0]
+
+    # Emit abs(x) into a temporary slot.
+    _, abs_tmp = P.make_tmp_slot()
+    P.emit(
+        AbsNode(
+            x=P.slot_to_tid(x),
+            out=P.slot_to_tid(abs_tmp),
+        )
+    )
+
+    # Create the inf constant (float32; EqualNode handles type promotion to match the input dtype).
+    inf_slot = emit_lifted_constant(P, float('inf'), torch.float32)
+
+    # Compare abs(x) == inf.
+    out = P.make_or_get_slot(n)
+    P.emit(
+        EqualNode(
+            a=P.slot_to_tid(abs_tmp),
+            b=P.slot_to_tid(inf_slot),
+            out=P.slot_to_tid(out),
+        )
+    )
+    return out
+
+
 _BINARY_OPS: List[Tuple[List[Any], Any, str, bool]] = [
     (
         [torch.ops.aten.mul.Tensor, torch.ops.aten.mul.Scalar],
diff --git a/backends/mlx/test/test_ops.py b/backends/mlx/test/test_ops.py
index 7ba3902e436..60119dd4a60 100644
--- a/backends/mlx/test/test_ops.py
+++ b/backends/mlx/test/test_ops.py
@@ -4020,6 +4020,21 @@ def fn(shape, dtype):
     return fn
 
 
+def _inf_input_fn():
+    """Return a callable(shape, dtype) that generates inputs with some inf values."""
+
+    def fn(shape, dtype):
+        x = torch.randn(shape, dtype=dtype)
+        # Insert ~20% +inf and ~8% -inf overall, using non-overlapping masks.
+        mask_pos = torch.rand(shape) > 0.8  # ~20% -> +inf
+        mask_neg = (~mask_pos) & (torch.rand(shape) > 0.9)  # ~10% of remaining -> -inf
+        x[mask_pos] = float('inf')
+        x[mask_neg] = float('-inf')
+        return (x,)
+
+    return fn
+
+
 # Standard shape and dtype configs used by unary tests.
 _SHAPES_3 = [(16,), (4, 4), (2, 3, 4)]
 _SHAPES_2 = [(16,), (4, 4)]
@@ -4112,6 +4127,7 @@ def create_model(self) -> nn.Module:
     {"op_name": "neg", "op_fn": torch.neg},
     {"op_name": "logical_not", "op_fn": torch.logical_not, "shapes": [(2, 3, 4), (10,), (4, 8)], "dtypes": [torch.bool], "input_fn": _bool_input_fn()},
     {"op_name": "isnan", "op_fn": torch.isnan, "shapes": _SHAPES_3, "dtypes": [torch.float32, torch.float16, torch.bfloat16], "input_fn": _nan_input_fn()},
+    {"op_name": "isinf", "op_fn": torch.isinf, "shapes": _SHAPES_3, "dtypes": [torch.float32, torch.float16, torch.bfloat16], "input_fn": _inf_input_fn()},
     # activations
     {"op_name": "relu", "op_fn": torch.relu, "shapes": [(2, 3, 4), (10,), (4, 8), (2, 8, 16), (1, 128, 64)], "dtypes": [torch.float32], "input_fn": _input_fn(scale=2, offset=-1)},
     {"op_name": "sigmoid", "op_fn": torch.sigmoid, "shapes": [(2, 3, 4), (10,), (4, 8), (2, 8, 16), (1, 1, 128)], "dtypes": [torch.float32], "input_fn": _input_fn(scale=2)},
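
Note (not part of the patch): the handler leans on the identity `isinf(x) == (abs(x) == inf)`, comparing against a float32 `inf` constant even when the input is float16 or bfloat16. Below is a minimal plain-PyTorch sketch checking that identity across the dtypes the new test covers, including the NaN case, which must map to False. The helper name `check_isinf_decomposition` is illustrative only.

```python
import torch

def check_isinf_decomposition():
    # Illustrative only: verify that abs(x) == inf matches torch.isinf
    # element-wise. The inf constant is float32, mirroring the handler;
    # whichever way the comparison promotes, inf is exactly representable
    # in fp16/bf16, so the result is unchanged.
    inf_const = torch.tensor(float("inf"), dtype=torch.float32)
    for dtype in (torch.float32, torch.float16, torch.bfloat16):
        x = torch.randn(4, 4, dtype=dtype)
        x[0, 0] = float("inf")
        x[1, 1] = float("-inf")   # abs(-inf) == inf, so this maps to True
        x[2, 2] = float("nan")    # abs(nan) == inf is False, matching isinf
        decomposed = torch.abs(x) == inf_const
        assert torch.equal(decomposed, torch.isinf(x)), dtype

check_isinf_decomposition()  # passes silently if the identity holds
```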