diff --git a/docs/source/metrics.rst b/docs/source/metrics.rst index 53ab6ff0ecfb..2261498c8be0 100644 --- a/docs/source/metrics.rst +++ b/docs/source/metrics.rst @@ -396,7 +396,7 @@ Complete list of metrics fairness.SubgroupDifference fairness.SubgroupMetric rec_sys.HitRate - + rec_sys.NDCG .. note:: diff --git a/ignite/metrics/__init__.py b/ignite/metrics/__init__.py index 2175221a8416..b1813cc92935 100644 --- a/ignite/metrics/__init__.py +++ b/ignite/metrics/__init__.py @@ -38,6 +38,7 @@ from ignite.metrics.psnr import PSNR from ignite.metrics.recall import Recall from ignite.metrics.rec_sys.hitrate import HitRate +from ignite.metrics.rec_sys.ndcg import NDCG from ignite.metrics.roc_auc import ROC_AUC, RocCurve from ignite.metrics.root_mean_squared_error import RootMeanSquaredError from ignite.metrics.running_average import RunningAverage @@ -106,4 +107,5 @@ "CommonObjectDetectionMetrics", "coco_tensor_list_to_dict_list", "HitRate", + "NDCG", ] diff --git a/ignite/metrics/rec_sys/__init__.py b/ignite/metrics/rec_sys/__init__.py index f6f37785cb4e..6fa17fdca928 100644 --- a/ignite/metrics/rec_sys/__init__.py +++ b/ignite/metrics/rec_sys/__init__.py @@ -1 +1,4 @@ from ignite.metrics.rec_sys.hitrate import HitRate +from ignite.metrics.rec_sys.ndcg import NDCG + +__all__ = ["HitRate", "NDCG"] diff --git a/ignite/metrics/rec_sys/ndcg.py b/ignite/metrics/rec_sys/ndcg.py new file mode 100644 index 000000000000..b9ac49aeaeac --- /dev/null +++ b/ignite/metrics/rec_sys/ndcg.py @@ -0,0 +1,234 @@ +from typing import Callable + +import torch + +from ignite.exceptions import NotComputableError +from ignite.metrics.metric import Metric, reinit__is_reduced, sync_all_reduce + +__all__ = ["NDCG"] + + +def _get_ranked_items(scores: torch.Tensor, items: torch.Tensor, k: int) -> torch.Tensor: + """Get top-k items ranked by scores.""" + indices = torch.argsort(scores, dim=-1, descending=True, stable=True)[:, :k] + return torch.gather(items, dim=-1, index=indices) + + +class 
NDCG(Metric): + r"""Calculates the Normalized Discounted Cumulative Gain (NDCG) at `k` for Recommendation Systems. + + For a step-by-step guide on how to use this metric, please refer to the + NDCG tutorial. + + NDCG measures the quality of ranking by considering both the relevance of items and their + positions in the ranked list. It compares the achieved DCG against the ideal DCG (IDCG) + obtained by sorting items by their true relevance. + + .. math:: + \text{NDCG}@K = \frac{\text{DCG}@K}{\text{IDCG}@K} + + where: + + .. math:: + \text{DCG}@K = \sum_{i=1}^{K} \frac{2^{\text{rel}_i} - 1}{\log_2(i + 1)} + + and :math:`\text{rel}_i` is the relevance score of the item at position :math:`i` in the + ranked list (1-indexed). + + - ``update`` must receive output of the form ``(y_pred, y)``. + - ``y_pred`` is expected to be raw logits or probability score for each item in the catalog. + - ``y`` is expected to contain relevance scores (can be binary or graded). + + Relevance Types: + - **Binary relevance**: Labels are either 0 (not relevant) or 1 (relevant) + - **Graded relevance**: Labels can have multiple levels (e.g., 0-4 scale) + + Common graded scales: + - 0: Not relevant + - 1: Marginally relevant + - 2: Relevant + - 3: Highly relevant + - 4: Perfectly relevant + + The NDCG formula handles both types through the gain function: 2^relevance - 1. + Higher relevance scores contribute more to the metric. + + - ``y_pred`` and ``y`` are only allowed shape :math:`(batch, num\_items)`. + - returns a list of NDCG ordered by the sorted values of ``top_k``. + + Args: + top_k: a list of sorted positive integers that specifies `k` for calculating NDCG@top-k. + ignore_zero_hits: if True, users with no relevant items (ground truth tensor being all zeros) + are ignored in computation of NDCG. If set False, such users are counted with NDCG of 0. + By default, True. + relevance_threshold: minimum label value to be considered relevant. 
Defaults to ``1``, + which handles standard binary labels and graded relevance scales (e.g. TREC-style + 0-4) by treating any label >= 1 as relevant. Items below this threshold contribute + 0 to DCG/IDCG calculations. + gain_function (str): Gain function for relevance scores. Options: + - ``'exp_rank'``: 2^relevance - 1 (emphasizes high relevance, default) + - ``'linear_rank'``: relevance (simpler, linear scale) + Defaults to ``'exp_rank'``. + output_transform: a callable that is used to transform the + :class:`~ignite.engine.engine.Engine`'s ``process_function``'s output into the + form expected by the metric. + The output is expected to be a tuple `(prediction, target)` + where `prediction` and `target` are tensors + of shape ``(batch, num_items)``. + device: specifies which device updates are accumulated on. Setting the + metric's device to be the same as your ``update`` arguments ensures the ``update`` method is + non-blocking. By default, CPU. + skip_unrolling: specifies whether input should be unrolled or not before being + processed. Should be true for multi-output models. + + Examples: + To use with ``Engine`` and ``process_function``, simply attach the metric instance to the engine. + The output of the engine's ``process_function`` needs to be in the format of + ``(y_pred, y)``. If not, ``output_transform`` can be added + to the metric to transform the output into the form expected by the metric. + + For more information on how metric works with :class:`~ignite.engine.engine.Engine`, visit :ref:`attach-engine`. + + .. include:: defaults.rst + :start-after: :orphan: + + ignore_zero_hits=True case + + .. testcode:: 1 + + metric = NDCG(top_k=[1, 2, 3, 4]) + metric.attach(default_evaluator, "ndcg") + y_pred=torch.Tensor([ + [4.0, 2.0, 3.0, 1.0], + [1.0, 2.0, 3.0, 4.0] + ]) + y_true=torch.Tensor([ + [0.0, 0.0, 1.0, 1.0], + [0.0, 0.0, 0.0, 0.0] + ]) + state = default_evaluator.run([(y_pred, y_true)]) + print(state.metrics["ndcg"]) + + .. 
testoutput:: 1 + + [0.0, 0.38..., 0.38..., 0.65...] + + ignore_zero_hits=False case + + .. testcode:: 2 + + metric = NDCG(top_k=[1, 2, 3, 4], ignore_zero_hits=False) + metric.attach(default_evaluator, "ndcg") + y_pred=torch.Tensor([ + [4.0, 2.0, 3.0, 1.0], + [1.0, 2.0, 3.0, 4.0] + ]) + y_true=torch.Tensor([ + [0.0, 0.0, 1.0, 1.0], + [0.0, 0.0, 0.0, 0.0] + ]) + state = default_evaluator.run([(y_pred, y_true)]) + print(state.metrics["ndcg"]) + + .. testoutput:: 2 + + [0.0, 0.19..., 0.19..., 0.32...] + + .. versionadded:: 0.6.0 + """ + + required_output_keys = ("y_pred", "y") + _state_dict_all_req_keys = ("_sum_ndcg_per_k", "_num_examples") + + def __init__( + self, + top_k: list[int], + ignore_zero_hits: bool = True, + relevance_threshold: float = 1.0, + gain_function: str = "exp_rank", + output_transform: Callable = lambda x: x, + device: str | torch.device = torch.device("cpu"), + skip_unrolling: bool = False, + ): + if any(k <= 0 for k in top_k): + raise ValueError(f"top_k must be list of positive integers only, but given {top_k}") + + if gain_function not in ["exp_rank", "linear_rank"]: + raise ValueError(f"gain_function must be either 'exp_rank' or 'linear_rank', but given {gain_function}") + + self.top_k = sorted(top_k) + self.ignore_zero_hits = ignore_zero_hits + self.relevance_threshold = relevance_threshold + self.gain_function = gain_function + super(NDCG, self).__init__(output_transform, device=device, skip_unrolling=skip_unrolling) + + @reinit__is_reduced + def reset(self) -> None: + self._sum_ndcg_per_k = torch.zeros(len(self.top_k), device=self._device) + self._num_examples = 0 + + def _compute_dcg(self, relevance_scores: torch.Tensor, k: int) -> torch.Tensor: + """Compute DCG@k for a batch of relevance scores.""" + actual_k = min(k, relevance_scores.shape[1]) + + positions = torch.arange(1, actual_k + 1, dtype=torch.float32, device=relevance_scores.device) + discounts = 1.0 / torch.log2(positions + 1) + + topk_relevance = relevance_scores[:, :actual_k] 
+ if self.gain_function == "exp_rank": + gains = torch.pow(2.0, topk_relevance) - 1 + else: + gains = topk_relevance + + dcg = (gains * discounts).sum(dim=-1) + return dcg + + @reinit__is_reduced + def update(self, output: tuple[torch.Tensor, torch.Tensor]) -> None: + if len(output) != 2: + raise ValueError(f"output should be in format `(y_pred,y)` but got tuple of {len(output)} tensors.") + + y_pred, y = output + if y_pred.shape != y.shape: + raise ValueError(f"y_pred and y must be in the same shape, got {y_pred.shape} != {y.shape}.") + + if self.ignore_zero_hits: + valid_mask = torch.any(y >= self.relevance_threshold, dim=-1) + y_pred = y_pred[valid_mask] + y = y[valid_mask] + + if y.shape[0] == 0: + return + + y_for_dcg = torch.where(y >= self.relevance_threshold, y, 0) + + max_k = self.top_k[-1] + ranked_relevance = _get_ranked_items(y_pred, y_for_dcg, max_k) + # Compute ideal ranking by sorting true relevance scores (descending). + # This aligns with standard IDCG computation in reference libraries: + # ranx: https://github.com/AmenRa/ranx/blob/master/ranx/metrics/ndcg.py#L52 + # catalyst: https://github.com/catalyst-team/catalyst/blob/master/catalyst/metrics/functional/_ndcg.py#L197 + + ideal_relevance = torch.sort(y_for_dcg, dim=-1, descending=True, stable=True)[0][:, :max_k] + + for i, k in enumerate(self.top_k): + dcg_k = self._compute_dcg(ranked_relevance, k) + idcg_k = self._compute_dcg(ideal_relevance, k) + + ndcg_k = torch.where( + idcg_k > 0, + dcg_k / idcg_k, + torch.zeros_like(dcg_k), + ) + + self._sum_ndcg_per_k[i] += ndcg_k.sum().to(self._device) + + self._num_examples += y.shape[0] + + @sync_all_reduce("_sum_ndcg_per_k", "_num_examples") + def compute(self) -> list[float]: + if self._num_examples == 0: + raise NotComputableError("NDCG must have at least one example.") + + rates = (self._sum_ndcg_per_k / self._num_examples).tolist() + return rates diff --git a/requirements-dev.txt b/requirements-dev.txt index b9bece494ea7..64d645fad721 100644 --- 
a/requirements-dev.txt +++ b/requirements-dev.txt @@ -35,3 +35,4 @@ pandas gymnasium # temporary fix: E AttributeError: module 'mpmath' has no attribute 'rational' mpmath<1.4 +ranx diff --git a/tests/ignite/metrics/rec_sys/test_ndcg.py b/tests/ignite/metrics/rec_sys/test_ndcg.py new file mode 100644 index 000000000000..9c5943a36642 --- /dev/null +++ b/tests/ignite/metrics/rec_sys/test_ndcg.py @@ -0,0 +1,341 @@ +import numpy as np +import pytest +import torch + +import ignite.distributed as idist +from ignite.engine import Engine +from ignite.exceptions import NotComputableError +from ignite.metrics.rec_sys.ndcg import NDCG +from ranx import Qrels, Run, evaluate + + +def ranx_ndcg( + y_pred: np.ndarray, + y: np.ndarray, + top_k: list[int], + ignore_zero_hits: bool = True, + relevance_threshold: float = 1.0, + gain_function: str = "exp_rank", +) -> list[float]: + """Reference NDCG implementation using ranx for verification. https://github.com/AmenRa/ranx""" + + sorted_top_k = sorted(top_k) + results = [] + + for k in sorted_top_k: + qrels_dict = {} + run_dict = {} + + for i, (scores, labels) in enumerate(zip(y_pred, y)): + qid = f"q{i}" + relevant = {f"d{j}": float(label) for j, label in enumerate(labels) if label >= relevance_threshold} + + if ignore_zero_hits and not relevant: + continue + + qrels_dict[qid] = relevant if relevant else {"d0": 0.0} + run_dict[qid] = {f"d{j}": float(s) for j, s in enumerate(scores)} + + if not qrels_dict: + results.append(0.0) + continue + + run_dict = {q: run_dict[q] for q in qrels_dict} + metric_name = f"{'ndcg_burges' if gain_function == 'exp_rank' else 'ndcg'}@{k}" + results.append(float(evaluate(Qrels(qrels_dict), Run(run_dict), metric_name))) + return results + + +def test_zero_sample(): + metric = NDCG(top_k=[1, 5]) + with pytest.raises(NotComputableError, match=r"NDCG must have at least one example"): + metric.compute() + + +def test_shape_mismatch(): + metric = NDCG(top_k=[1]) + y_pred = torch.randn(4, 10) + y = torch.ones(4, 
5) # Mismatched items count + with pytest.raises(ValueError, match="y_pred and y must be in the same shape"): + metric.update((y_pred, y)) + + +def test_invalid_top_k(): + with pytest.raises(ValueError, match="positive integers"): + NDCG(top_k=[0]) + with pytest.raises(ValueError, match="positive integers"): + NDCG(top_k=[-1, 5]) + + +def test_invalid_gain_function(): + with pytest.raises(ValueError, match="gain_function must be either"): + NDCG(top_k=[1], gain_function="invalid") + + +@pytest.mark.parametrize("top_k", [[1], [1, 2, 4]]) +@pytest.mark.parametrize("ignore_zero_hits", [True, False]) +def test_compute(top_k, ignore_zero_hits, available_device): + metric = NDCG( + top_k=top_k, + ignore_zero_hits=ignore_zero_hits, + device=available_device, + ) + + y_pred = torch.tensor([[4.0, 2.0, 3.0, 1.0], [1.0, 2.0, 3.0, 4.0]]) + y_true = torch.tensor([[0, 0, 1.0, 1.0], [0, 0, 0.0, 0.0]]) + + metric.update((y_pred, y_true)) + res = metric.compute() + + expected = ranx_ndcg( + y_pred.numpy(), + y_true.numpy(), + top_k, + ignore_zero_hits=ignore_zero_hits, + ) + + assert isinstance(res, list) + assert len(res) == len(top_k) + np.testing.assert_allclose(res, expected, rtol=1e-5) + + +@pytest.mark.parametrize("num_queries", [1, 10, 100]) +@pytest.mark.parametrize("num_items", [5, 20, 100]) +@pytest.mark.parametrize("k", [1, 5, 10]) +@pytest.mark.parametrize("ignore_zero_hits", [True, False]) +def test_compute_vs_ranx(num_queries, num_items, k, ignore_zero_hits, available_device): + """Verify NDCG matches ranx across a wide range of input shapes and k values.""" + torch.manual_seed(42) + y_pred = torch.randn(num_queries, num_items) + y_true = torch.randint(0, 5, (num_queries, num_items)).float() + + metric = NDCG( + top_k=[k], + ignore_zero_hits=ignore_zero_hits, + device=available_device, + ) + metric.update((y_pred, y_true)) + + try: + res = metric.compute() + except NotComputableError: + res = [0.0] + + expected = ranx_ndcg( + y_pred.numpy(), + y_true.numpy(), + 
top_k=[k], + ignore_zero_hits=ignore_zero_hits, + ) + + np.testing.assert_allclose(res, expected, rtol=1e-5) + + +@pytest.mark.parametrize("top_k", [[3], [2, 5]]) +@pytest.mark.parametrize("ignore_zero_hits", [True, False]) +def test_compute_vs_ranx_with_ties(top_k, ignore_zero_hits, available_device): + """Validate tie handling against ranx with non-trivial tie cases.""" + y_pred = torch.tensor( + [ + [0.7, 0.7, 0.7, 0.5, 0.5], + [0.9, 0.9, 0.6, 0.6, 0.3], + [0.8, 0.8, 0.5, 0.1, 0.1], + ] + ) + y_true = torch.tensor( + [ + [3.0, 2.0, 1.0, 1.0, 0.0], + [4.0, 3.0, 2.0, 1.0, 0.0], + [1.0, 1.0, 0.0, 0.0, 0.0], + ] + ) + + metric = NDCG( + top_k=top_k, + ignore_zero_hits=ignore_zero_hits, + gain_function="exp_rank", + device=available_device, + ) + metric.update((y_pred, y_true)) + res = metric.compute() + + expected_ranx = ranx_ndcg( + y_pred.numpy(), + y_true.numpy(), + top_k, + ignore_zero_hits=ignore_zero_hits, + ) + np.testing.assert_allclose(res, expected_ranx, rtol=1e-5) + + +@pytest.mark.parametrize("top_k", [[3], [2, 5]]) +@pytest.mark.parametrize("ignore_zero_hits", [True, False]) +def test_compute_vs_ranx_linear_gain(top_k, ignore_zero_hits, available_device): + """Validate linear gain mode against ranx.""" + y_pred = torch.tensor( + [ + [0.7, 0.7, 0.7, 0.5, 0.5], + [0.9, 0.9, 0.6, 0.6, 0.3], + ] + ) + y_true = torch.tensor( + [ + [3.0, 2.0, 1.0, 1.0, 0.0], + [4.0, 3.0, 2.0, 1.0, 0.0], + ] + ) + + metric = NDCG( + top_k=top_k, + ignore_zero_hits=ignore_zero_hits, + gain_function="linear_rank", + device=available_device, + ) + metric.update((y_pred, y_true)) + res = metric.compute() + + expected_ranx = ranx_ndcg( + y_pred.numpy(), + y_true.numpy(), + top_k, + ignore_zero_hits=ignore_zero_hits, + gain_function="linear_rank", + ) + + np.testing.assert_allclose(res, expected_ranx, rtol=1e-5) + + +def test_perfect_prediction(): + """Perfect ranking -> NDCG = 1.0.""" + metric = NDCG(top_k=[1, 3]) + y_pred = torch.tensor([[5.0, 3.0, 4.0, 1.0]]) + y_true = 
torch.tensor([[3.0, 1.0, 2.0, 0.0]]) + metric.update((y_pred, y_true)) + assert metric.compute() == pytest.approx([1.0, 1.0]) + + +def test_multiple_batches(): + """NDCG accumulates correctly across multiple update() calls.""" + metric = NDCG(top_k=[3], ignore_zero_hits=False) + + y_pred_1 = torch.tensor([[3.0, 2.0, 1.0]]) + y_true_1 = torch.tensor([[3.0, 2.0, 1.0]]) + + y_pred_2 = torch.tensor([[1.0, 2.0, 3.0]]) + y_true_2 = torch.tensor([[3.0, 2.0, 1.0]]) + + metric.update((y_pred_1, y_true_1)) + metric.update((y_pred_2, y_true_2)) + + expected_1 = ranx_ndcg(y_pred_1.numpy(), y_true_1.numpy(), [3], ignore_zero_hits=False, gain_function="exp_rank")[0] + expected_2 = ranx_ndcg(y_pred_2.numpy(), y_true_2.numpy(), [3], ignore_zero_hits=False, gain_function="exp_rank")[0] + expected = (expected_1 + expected_2) / 2.0 + + assert metric.compute() == pytest.approx([expected], rel=1e-5) + + +def test_all_zeros_relevance(): + """When all relevance is 0, IDCG=0, so NDCG should be 0 if ignore_zero_hits=False.""" + metric = NDCG(top_k=[2], ignore_zero_hits=False) + y_pred = torch.tensor([[5.0, 3.0, 4.0]]) + y_true = torch.tensor([[0.0, 0.0, 0.0]]) + metric.update((y_pred, y_true)) + assert metric.compute() == pytest.approx([0.0]) + + +def test_graded_relevance_threshold(): + """Labels >= relevance_threshold are considered, but contribute their full value to DCG.""" + metric = NDCG(top_k=[3], relevance_threshold=2.0) + + y_pred = torch.tensor([[0.9, 0.3, 0.7]]) + y_true = torch.tensor([[3.0, 1.0, 2.0]]) + metric.update((y_pred, y_true)) + + result = metric.compute() + assert result[0] == pytest.approx(1.0, rel=1e-5) + + +def test_accumulator_detached(available_device): + metric = NDCG(top_k=[1], device=available_device) + y_pred = torch.randn(4, 5, requires_grad=True) + y = torch.randint(0, 2, (4, 5)).float() + metric.update((y_pred, y)) + + assert metric._sum_ndcg_per_k.requires_grad is False + assert metric._sum_ndcg_per_k.is_leaf is True + + +def 
test_all_zero_targets_ignore(): + metric = NDCG(top_k=[1, 3], ignore_zero_hits=True) + y_pred = torch.randn(4, 5) + y = torch.zeros(4, 5) + metric.update((y_pred, y)) + with pytest.raises(NotComputableError): + metric.compute() + + +@pytest.mark.usefixtures("distributed") +class TestDistributed: + def test_integration(self): + n_iters = 10 + batch_size = 4 + num_items = 20 + top_k = [1, 5] + + rank = idist.get_rank() + torch.manual_seed(42 + rank) + device = idist.device() + + metric_devices = [torch.device("cpu")] + if device.type != "xla": + metric_devices.append(device) + + for metric_device in metric_devices: + all_y_true = torch.randint(0, 4, (n_iters * batch_size, num_items)).float().to(device) + all_y_pred = torch.randn((n_iters * batch_size, num_items)).to(device) + + for ignore_zero_hits in [True, False]: + engine = Engine( + lambda e, i: ( + all_y_pred[i * batch_size : (i + 1) * batch_size], + all_y_true[i * batch_size : (i + 1) * batch_size], + ) + ) + m = NDCG( + top_k=top_k, + ignore_zero_hits=ignore_zero_hits, + device=metric_device, + ) + m.attach(engine, "ndcg") + + engine.run(range(n_iters), max_epochs=1) + + global_y_true = idist.all_gather(all_y_true).cpu().numpy() + global_y_pred = idist.all_gather(all_y_pred).cpu().numpy() + + res = engine.state.metrics["ndcg"] + + true_res = ranx_ndcg( + global_y_pred, + global_y_true, + top_k, + ignore_zero_hits=ignore_zero_hits, + ) + + assert isinstance(res, list) + assert res == pytest.approx(true_res, rel=1e-5) + + engine.state.metrics.clear() + + def test_accumulator_device(self): + device = idist.device() + metric = NDCG(top_k=[1, 5], device=device) + + assert metric._device == device + assert metric._sum_ndcg_per_k.device == device + + y_pred = torch.randn(2, 10) + y = torch.zeros(2, 10) + metric.update((y_pred, y)) + + assert metric._sum_ndcg_per_k.device == device