diff --git a/docs/source/metrics.rst b/docs/source/metrics.rst index 53ab6ff0ecfb..2a7cc536c439 100644 --- a/docs/source/metrics.rst +++ b/docs/source/metrics.rst @@ -396,6 +396,7 @@ Complete list of metrics fairness.SubgroupDifference fairness.SubgroupMetric rec_sys.HitRate + rec_sys.MRR .. note:: diff --git a/ignite/metrics/__init__.py b/ignite/metrics/__init__.py index 2175221a8416..e3c7238433de 100644 --- a/ignite/metrics/__init__.py +++ b/ignite/metrics/__init__.py @@ -38,6 +38,7 @@ from ignite.metrics.psnr import PSNR from ignite.metrics.recall import Recall from ignite.metrics.rec_sys.hitrate import HitRate +from ignite.metrics.rec_sys.mrr import MRR from ignite.metrics.roc_auc import ROC_AUC, RocCurve from ignite.metrics.root_mean_squared_error import RootMeanSquaredError from ignite.metrics.running_average import RunningAverage @@ -106,4 +107,5 @@ "CommonObjectDetectionMetrics", "coco_tensor_list_to_dict_list", "HitRate", + "MRR", ] diff --git a/ignite/metrics/rec_sys/__init__.py b/ignite/metrics/rec_sys/__init__.py index f6f37785cb4e..6876625f6d98 100644 --- a/ignite/metrics/rec_sys/__init__.py +++ b/ignite/metrics/rec_sys/__init__.py @@ -1 +1,2 @@ from ignite.metrics.rec_sys.hitrate import HitRate +from ignite.metrics.rec_sys.mrr import MRR diff --git a/ignite/metrics/rec_sys/mrr.py b/ignite/metrics/rec_sys/mrr.py new file mode 100644 index 000000000000..d2a194a5e097 --- /dev/null +++ b/ignite/metrics/rec_sys/mrr.py @@ -0,0 +1,181 @@ +from typing import Callable + +import torch + +from ignite.exceptions import NotComputableError +from ignite.metrics.metric import Metric, reinit__is_reduced, sync_all_reduce + +__all__ = ["MRR"] + + +class MRR(Metric): + r"""Calculates the Mean Reciprocal Rank (MRR) at `k` for Recommendation Systems. + + MRR measures the average of the reciprocal of the rank of the first relevant item + in the predicted list. It is widely used in retrieval systems, recommendation systems, + and RAG pipelines. + + .. 
math:: \text{MRR}@K = \frac{1}{N} \sum_{i=1}^{N} \frac{1}{\text{rank}_i} + + where :math:`\text{rank}_i` is the rank (1-indexed) of the first relevant item + in the top-K predictions for user :math:`i`. If no relevant item is found in the + top-K, the reciprocal rank for that user is 0. + + - ``update`` must receive output of the form ``(y_pred, y)``. + - ``y_pred`` is expected to be raw logits or probability scores for each item in the catalog. + - ``y`` is expected to be binary (only 0s and 1s) values where `1` indicates a relevant item. + Graded relevance labels are also supported via ``relevance_threshold``. + - ``y_pred`` and ``y`` are only allowed shape :math:`(batch, num\_items)`. + - returns a list of MRR values ordered by the sorted values of ``top_k``. + + Args: + top_k: a list of sorted positive integers that specifies `k` for calculating MRR@top-k. + ignore_zero_hits: if True, users with no relevant items (no label at or above + ``relevance_threshold``) are ignored in the computation of MRR. If set to False, such + users are counted as having a reciprocal rank of 0. By default, True. + relevance_threshold: minimum label value to be considered relevant. Defaults to ``1``, + which handles standard binary labels and graded relevance scales (e.g. TREC-style + 0-4) by treating any label >= 1 as relevant. + output_transform: a callable that is used to transform the + :class:`~ignite.engine.engine.Engine`'s ``process_function``'s output into the + form expected by the metric. + The output is expected to be a tuple `(prediction, target)` + where `prediction` and `target` are tensors + of shape ``(batch, num_items)``. + device: specifies which device updates are accumulated on. Setting the + metric's device to be the same as your ``update`` arguments ensures the ``update`` method is + non-blocking. By default, CPU. + skip_unrolling: specifies whether input should be unrolled or not before being + processed. Should be true for multi-output models.
+ + Examples: + To use with ``Engine`` and ``process_function``, simply attach the metric instance to the engine. + The output of the engine's ``process_function`` needs to be in the format of + ``(y_pred, y)``. If not, ``output_transform`` can be added + to the metric to transform the output into the form expected by the metric. + + For more information on how the metric works with :class:`~ignite.engine.engine.Engine`, visit :ref:`attach-engine`. + + .. include:: defaults.rst + :start-after: :orphan: + + ignore_zero_hits=True case + + .. testcode:: 1 + + metric = MRR(top_k=[1, 2, 3, 4]) + metric.attach(default_evaluator,"mrr") + y_pred=torch.Tensor([ + [4.0, 2.0, 3.0, 1.0], + [1.0, 2.0, 3.0, 4.0] + ]) + y_true=torch.Tensor([ + [0.0, 0.0, 1.0, 1.0], + [0.0, 0.0, 0.0, 0.0] + ]) + state = default_evaluator.run([(y_pred, y_true)]) + print(state.metrics["mrr"]) + + .. testoutput:: 1 + + [0.0, 0.5, 0.5, 0.5] + + ignore_zero_hits=False case + + .. testcode:: 2 + + metric = MRR(top_k=[1, 2, 3, 4], ignore_zero_hits=False) + metric.attach(default_evaluator,"mrr") + y_pred=torch.Tensor([ + [4.0, 2.0, 3.0, 1.0], + [1.0, 2.0, 3.0, 4.0] + ]) + y_true=torch.Tensor([ + [0.0, 0.0, 1.0, 1.0], + [0.0, 0.0, 0.0, 0.0] + ]) + state = default_evaluator.run([(y_pred, y_true)]) + print(state.metrics["mrr"]) + + .. testoutput:: 2 + + [0.0, 0.25, 0.25, 0.25] + + ..
versionadded:: 0.6.0 + """ + + required_output_keys = ("y_pred", "y") + _state_dict_all_req_keys = ("_sum_reciprocal_ranks_per_k", "_num_examples") + + def __init__( + self, + top_k: list[int], + ignore_zero_hits: bool = True, + relevance_threshold: float = 1.0, + output_transform: Callable = lambda x: x, + device: str | torch.device = torch.device("cpu"), + skip_unrolling: bool = False, + ): + if any(k <= 0 for k in top_k): + raise ValueError(" top_k must be list of positive integers only.") + + self.top_k = sorted(top_k) + self.ignore_zero_hits = ignore_zero_hits + self.relevance_threshold = relevance_threshold + super(MRR, self).__init__(output_transform, device=device, skip_unrolling=skip_unrolling) + + @reinit__is_reduced + def reset(self) -> None: + self._sum_reciprocal_ranks_per_k = torch.zeros(len(self.top_k), device=self._device) + self._num_examples = 0 + + @reinit__is_reduced + def update(self, output: tuple[torch.Tensor, torch.Tensor]) -> None: + if len(output) != 2: + raise ValueError(f"output should be in format `(y_pred,y)` but got tuple of {len(output)} tensors.") + + y_pred, y = output + if y_pred.shape != y.shape: + raise ValueError(f"y_pred and y must be in the same shape, got {y_pred.shape} != {y.shape}.") + + if self.ignore_zero_hits: + valid_mask = torch.any(y >= self.relevance_threshold, dim=-1) + y_pred = y_pred[valid_mask] + y = y[valid_mask] + + if y.shape[0] == 0: + return + + max_k = self.top_k[-1] + + # stable=True ensures deterministic tie-breaking, consistent with + # reference libraries such as ranx. 
+ ranked_indices = torch.argsort(y_pred, dim=-1, descending=True, stable=True)[:, :max_k] + ranked_labels = torch.gather(y, dim=-1, index=ranked_indices) + + for i, k in enumerate(self.top_k): + top_k_labels = ranked_labels[:, :k] + relevant_mask = top_k_labels >= self.relevance_threshold + + has_hit = relevant_mask.any(dim=-1) + + # argmax on int tensor returns 0-based position of first True + first_hit_pos = relevant_mask.int().argmax(dim=-1) + + reciprocal_rank = torch.where( + has_hit, + 1.0 / (first_hit_pos.float() + 1.0), + torch.zeros_like(first_hit_pos, dtype=torch.float), + ) + + self._sum_reciprocal_ranks_per_k[i] += reciprocal_rank.sum().to(self._device) + + self._num_examples += y.shape[0] + + @sync_all_reduce("_sum_reciprocal_ranks_per_k", "_num_examples") + def compute(self) -> list[float]: + if self._num_examples == 0: + raise NotComputableError("MRR must have at least one example.") + + rates = (self._sum_reciprocal_ranks_per_k / self._num_examples).tolist() + return rates diff --git a/requirements-dev.txt b/requirements-dev.txt index b9bece494ea7..19cc3f604423 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -10,6 +10,7 @@ pytest-order pytest-timeout pytest-xdist setuptools +ranx fairlearn # Test contrib dependencies scipy diff --git a/tests/ignite/metrics/rec_sys/test_mrr_metric.py b/tests/ignite/metrics/rec_sys/test_mrr_metric.py new file mode 100644 index 000000000000..2832fc8b96da --- /dev/null +++ b/tests/ignite/metrics/rec_sys/test_mrr_metric.py @@ -0,0 +1,261 @@ +import numpy as np +import pytest +import torch + +import ignite.distributed as idist +from ignite.engine import Engine +from ignite.exceptions import NotComputableError +from ignite.metrics.rec_sys.mrr import MRR + + +def ranx_mrr( + y_pred: np.ndarray, + y: np.ndarray, + top_k: list[int], + ignore_zero_hits: bool = True, +) -> list[float]: + """Reference MRR implementation using ranx for verification. 
https://github.com/AmenRa/ranx """ + from ranx import Qrels, Run, evaluate + + sorted_top_k = sorted(top_k) + results = [] + + for k in sorted_top_k: + qrels_dict = {} + run_dict = {} + + for i, (scores, labels) in enumerate(zip(y_pred, y)): + qid = f"q{i}" + relevant = {f"d{j}": int(label) for j, label in enumerate(labels) if label > 0} + + if ignore_zero_hits and not relevant: + continue + + qrels_dict[qid] = relevant if relevant else {f"d0": 0} + run_dict[qid] = {f"d{j}": float(s) for j, s in enumerate(scores)} + + if not qrels_dict: + results.append(0.0) + continue + + run_dict = {q: run_dict[q] for q in qrels_dict} + results.append(float(evaluate(Qrels(qrels_dict), Run(run_dict), f"mrr@{k}"))) + + return results + + +def test_zero_sample(): + metric = MRR(top_k=[1, 5]) + with pytest.raises(NotComputableError, match=r"MRR must have at least one example"): + metric.compute() + + +def test_shape_mismatch(): + metric = MRR(top_k=[1]) + y_pred = torch.randn(4, 10) + y = torch.ones(4, 5) # Mismatched items count + with pytest.raises(ValueError, match="y_pred and y must be in the same shape"): + metric.update((y_pred, y)) + + +def test_invalid_top_k(): + with pytest.raises(ValueError, match="positive integers"): + MRR(top_k=[0]) + with pytest.raises(ValueError, match="positive integers"): + MRR(top_k=[-1, 5]) + + +@pytest.mark.parametrize("top_k", [[1], [1, 2, 4]]) +@pytest.mark.parametrize("ignore_zero_hits", [True, False]) +def test_compute(top_k, ignore_zero_hits, available_device): + metric = MRR( + top_k=top_k, + ignore_zero_hits=ignore_zero_hits, + device=available_device, + ) + + y_pred = torch.tensor([[4.0, 2.0, 3.0, 1.0], [1.0, 2.0, 3.0, 4.0]]) + y_true = torch.tensor([[0, 0, 1.0, 1.0], [0, 0, 0.0, 0.0]]) + + metric.update((y_pred, y_true)) + res = metric.compute() + + expected = ranx_mrr( + y_pred.numpy(), + y_true.numpy(), + top_k, + ignore_zero_hits=ignore_zero_hits, + ) + + assert isinstance(res, list) + assert len(res) == len(top_k) + 
np.testing.assert_allclose(res, expected) + + +@pytest.mark.parametrize("num_queries", [1, 10, 100]) +@pytest.mark.parametrize("num_items", [5, 20, 100]) +@pytest.mark.parametrize("k", [1, 5, 10]) +@pytest.mark.parametrize("ignore_zero_hits", [True, False]) +def test_compute_vs_ranx(num_queries, num_items, k, ignore_zero_hits, available_device): + """Verify MRR matches ranx across a wide range of input shapes and k values.""" + torch.manual_seed(42) + y_pred = torch.randn(num_queries, num_items) + y_true = torch.randint(0, 2, (num_queries, num_items)).float() + + metric = MRR( + top_k=[k], + ignore_zero_hits=ignore_zero_hits, + device=available_device, + ) + metric.update((y_pred, y_true)) + + try: + res = metric.compute() + except NotComputableError: + res = [0.0] + + expected = ranx_mrr( + y_pred.numpy(), + y_true.numpy(), + top_k=[k], + ignore_zero_hits=ignore_zero_hits, + ) + + np.testing.assert_allclose(res, expected, rtol=1e-5) + + +def test_known_values(): + """Test with manually computed expected values.""" + metric = MRR(top_k=[1, 2, 3, 4]) + # y_pred=[4,2,3,1] -> rank order: doc0, doc2, doc1, doc3 + # y=[0,0,1,1] -> relevance at ranked positions: [0,1,0,1] + # MRR@1: no hit at rank 1 -> 0 + # MRR@2: first hit at rank 2 -> 1/2 = 0.5 + # MRR@3: first hit at rank 2 -> 1/2 = 0.5 + # MRR@4: first hit at rank 2 -> 1/2 = 0.5 + y_pred = torch.tensor([[4.0, 2.0, 3.0, 1.0]]) + y_true = torch.tensor([[0.0, 0.0, 1.0, 1.0]]) + metric.update((y_pred, y_true)) + assert metric.compute() == pytest.approx([0.0, 0.5, 0.5, 0.5]) + + +def test_perfect_prediction(): + """Relevant item is top-ranked -> RR = 1.0.""" + metric = MRR(top_k=[1, 3]) + y_pred = torch.tensor([[5.0, 1.0, 2.0]]) + y_true = torch.tensor([[1.0, 0.0, 0.0]]) + metric.update((y_pred, y_true)) + assert metric.compute() == pytest.approx([1.0, 1.0]) + + +def test_multiple_batches(): + """RR accumulates correctly across multiple update() calls.""" + metric = MRR(top_k=[2]) + # batch 1: relevant at rank 2 -> RR = 
0.5 + metric.update(( + torch.tensor([[4.0, 2.0, 3.0, 1.0]]), + torch.tensor([[0.0, 0.0, 1.0, 1.0]]), + )) + # batch 2: relevant at rank 1 -> RR = 1.0 + metric.update(( + torch.tensor([[5.0, 1.0, 2.0, 3.0]]), + torch.tensor([[1.0, 0.0, 0.0, 0.0]]), + )) + # MRR = (0.5 + 1.0) / 2 = 0.75 + assert metric.compute() == pytest.approx([0.75]) + + +def test_graded_relevance(): + """Labels >= relevance_threshold are treated as relevant.""" + # relevance_threshold=2: label=1 is NOT relevant, label=2 IS relevant + metric = MRR(top_k=[3], relevance_threshold=2.0) + y_pred = torch.tensor([[0.9, 0.7, 0.3]]) # rank order: d0, d1, d2 + y_true = torch.tensor([[1.0, 0.0, 2.0]]) # only d2 (rank 3) is relevant + metric.update((y_pred, y_true)) + assert metric.compute() == pytest.approx([1 / 3]) + + +def test_accumulator_detached(available_device): + metric = MRR(top_k=[1], device=available_device) + y_pred = torch.randn(4, 5, requires_grad=True) + y = torch.randint(0, 2, (4, 5)).float() + metric.update((y_pred, y)) + + assert metric._sum_reciprocal_ranks_per_k.requires_grad is False + assert metric._sum_reciprocal_ranks_per_k.is_leaf is True + + +def test_all_zero_targets_ignore(): + metric = MRR(top_k=[1, 3], ignore_zero_hits=True) + y_pred = torch.randn(4, 5) + y = torch.zeros(4, 5) + metric.update((y_pred, y)) + with pytest.raises(NotComputableError): + metric.compute() + + +@pytest.mark.usefixtures("distributed") +class TestDistributed: + def test_integration(self): + n_iters = 10 + batch_size = 4 + num_items = 20 + top_k = [1, 5] + + rank = idist.get_rank() + torch.manual_seed(42 + rank) + device = idist.device() + + metric_devices = [torch.device("cpu")] + if device.type != "xla": + metric_devices.append(device) + + for metric_device in metric_devices: + all_y_true = torch.randint(0, 2, (n_iters * batch_size, num_items)).float().to(device) + all_y_pred = torch.randn((n_iters * batch_size, num_items)).to(device) + + for ignore_zero_hits in [True, False]: + engine = Engine( + 
lambda e, i: ( + all_y_pred[i * batch_size : (i + 1) * batch_size], + all_y_true[i * batch_size : (i + 1) * batch_size], + ) + ) + m = MRR( + top_k=top_k, + ignore_zero_hits=ignore_zero_hits, + device=metric_device, + ) + m.attach(engine, "mrr") + + engine.run(range(n_iters), max_epochs=1) + + global_y_true = idist.all_gather(all_y_true).cpu().numpy() + global_y_pred = idist.all_gather(all_y_pred).cpu().numpy() + + res = engine.state.metrics["mrr"] + + true_res = ranx_mrr( + global_y_pred, + global_y_true, + top_k, + ignore_zero_hits=ignore_zero_hits, + ) + + assert isinstance(res, list) + assert res == pytest.approx(true_res) + + engine.state.metrics.clear() + + def test_accumulator_device(self): + device = idist.device() + metric = MRR(top_k=[1, 5], device=device) + + assert metric._device == device + assert metric._sum_reciprocal_ranks_per_k.device == device + + y_pred = torch.randn(2, 10) + y = torch.zeros(2, 10) + metric.update((y_pred, y)) + + assert metric._sum_reciprocal_ranks_per_k.device == device