|
| 1 | +from typing import Callable |
| 2 | + |
| 3 | +import torch |
| 4 | + |
| 5 | +from ignite.exceptions import NotComputableError |
| 6 | +from ignite.metrics.metric import Metric, reinit__is_reduced, sync_all_reduce |
| 7 | + |
| 8 | +__all__ = ["NDCG"] |
| 9 | + |
| 10 | + |
| 11 | +class NDCG(Metric): |
| 12 | + r"""Calculates the Normalized Discounted Cumulative Gain (NDCG) at `k` for Recommendation Systems. |
| 13 | +
|
| 14 | + NDCG measures the quality of ranking by considering both the relevance of items and their |
| 15 | + positions in the ranked list. It compares the achieved DCG against the ideal DCG (IDCG) |
| 16 | + obtained by sorting items by their true relevance. |
| 17 | +
|
| 18 | + .. math:: |
| 19 | + \text{NDCG}@K = \frac{\text{DCG}@K}{\text{IDCG}@K} |
| 20 | +
|
| 21 | + where: |
| 22 | +
|
| 23 | + .. math:: |
| 24 | + \text{DCG}@K = \sum_{i=1}^{K} \frac{2^{\text{rel}_i} - 1}{\log_2(i + 1)} |
| 25 | +
|
| 26 | + and :math:`\text{rel}_i` is the relevance score of the item at position :math:`i` in the |
| 27 | + ranked list (1-indexed). |
| 28 | +
|
| 29 | + - ``update`` must receive output of the form ``(y_pred, y)``. |
| 30 | + - ``y_pred`` is expected to be raw logits or probability score for each item in the catalog. |
| 31 | + - ``y`` is expected to contain relevance scores (can be binary or graded). |
| 32 | + - ``y_pred`` and ``y`` are only allowed shape :math:`(batch, num\_items)`. |
| 33 | + - returns a list of NDCG ordered by the sorted values of ``top_k``. |
| 34 | +
|
| 35 | + Args: |
| 36 | + top_k: a list of sorted positive integers that specifies `k` for calculating NDCG@top-k. |
| 37 | + ignore_zero_hits: if True, users with no relevant items (ground truth tensor being all zeros) |
| 38 | + are ignored in computation of NDCG. If set False, such users are counted with NDCG of 0. |
| 39 | + By default, True. |
| 40 | + relevance_threshold: minimum label value to be considered relevant. Defaults to ``1``, |
| 41 | + which handles standard binary labels and graded relevance scales (e.g. TREC-style |
| 42 | + 0-4) by treating any label >= 1 as relevant. Items below this threshold contribute |
| 43 | + 0 to DCG/IDCG calculations. |
| 44 | + output_transform: a callable that is used to transform the |
| 45 | + :class:`~ignite.engine.engine.Engine`'s ``process_function``'s output into the |
| 46 | + form expected by the metric. |
| 47 | + The output is expected to be a tuple `(prediction, target)` |
| 48 | + where `prediction` and `target` are tensors |
| 49 | + of shape ``(batch, num_items)``. |
| 50 | + device: specifies which device updates are accumulated on. Setting the |
| 51 | + metric's device to be the same as your ``update`` arguments ensures the ``update`` method is |
| 52 | + non-blocking. By default, CPU. |
| 53 | + skip_unrolling: specifies whether input should be unrolled or not before being |
| 54 | + processed. Should be true for multi-output models.. |
| 55 | +
|
| 56 | + Examples: |
| 57 | + To use with ``Engine`` and ``process_function``, simply attach the metric instance to the engine. |
| 58 | + The output of the engine's ``process_function`` needs to be in the format of |
| 59 | + ``(y_pred, y)``. If not, ``output_tranform`` can be added |
| 60 | + to the metric to transform the output into the form expected by the metric. |
| 61 | +
|
| 62 | + For more information on how metric works with :class:`~ignite.engine.engine.Engine`, visit :ref:`attach-engine`. |
| 63 | +
|
| 64 | + .. include:: defaults.rst |
| 65 | + :start-after: :orphan: |
| 66 | +
|
| 67 | + ignore_zero_hits=True case |
| 68 | +
|
| 69 | + .. testcode:: 1 |
| 70 | +
|
| 71 | + metric = NDCG(top_k=[1, 2, 3, 4]) |
| 72 | + metric.attach(default_evaluator, "ndcg") |
| 73 | + y_pred=torch.Tensor([ |
| 74 | + [4.0, 2.0, 3.0, 1.0], |
| 75 | + [1.0, 2.0, 3.0, 4.0] |
| 76 | + ]) |
| 77 | + y_true=torch.Tensor([ |
| 78 | + [0.0, 0.0, 1.0, 1.0], |
| 79 | + [0.0, 0.0, 0.0, 0.0] |
| 80 | + ]) |
| 81 | + state = default_evaluator.run([(y_pred, y_true)]) |
| 82 | + print(state.metrics["ndcg"]) |
| 83 | +
|
| 84 | + .. testoutput:: 1 |
| 85 | +
|
| 86 | + [0.0, 0.63..., 0.63..., 0.63...] |
| 87 | +
|
| 88 | + ignore_zero_hits=False case |
| 89 | +
|
| 90 | + .. testcode:: 2 |
| 91 | +
|
| 92 | + metric = NDCG(top_k=[1, 2, 3, 4], ignore_zero_hits=False) |
| 93 | + metric.attach(default_evaluator, "ndcg") |
| 94 | + y_pred=torch.Tensor([ |
| 95 | + [4.0, 2.0, 3.0, 1.0], |
| 96 | + [1.0, 2.0, 3.0, 4.0] |
| 97 | + ]) |
| 98 | + y_true=torch.Tensor([ |
| 99 | + [0.0, 0.0, 1.0, 1.0], |
| 100 | + [0.0, 0.0, 0.0, 0.0] |
| 101 | + ]) |
| 102 | + state = default_evaluator.run([(y_pred, y_true)]) |
| 103 | + print(state.metrics["ndcg"]) |
| 104 | +
|
| 105 | + .. testoutput:: 2 |
| 106 | +
|
| 107 | + [0.0, 0.31..., 0.31..., 0.31...] |
| 108 | +
|
| 109 | + .. versionadded:: 0.6.0 |
| 110 | + """ |
| 111 | + |
| 112 | + required_output_keys = ("y_pred", "y") |
| 113 | + _state_dict_all_req_keys = ("_sum_ndcg_per_k", "_num_examples") |
| 114 | + |
| 115 | + def __init__( |
| 116 | + self, |
| 117 | + top_k: list[int], |
| 118 | + ignore_zero_hits: bool = True, |
| 119 | + relevance_threshold: float = 1.0, |
| 120 | + output_transform: Callable = lambda x: x, |
| 121 | + device: str | torch.device = torch.device("cpu"), |
| 122 | + skip_unrolling: bool = False, |
| 123 | + ): |
| 124 | + if any(k <= 0 for k in top_k): |
| 125 | + raise ValueError(" top_k must be list of positive integers only.") |
| 126 | + |
| 127 | + self.top_k = sorted(top_k) |
| 128 | + self.ignore_zero_hits = ignore_zero_hits |
| 129 | + self.relevance_threshold = relevance_threshold |
| 130 | + super(NDCG, self).__init__(output_transform, device=device, skip_unrolling=skip_unrolling) |
| 131 | + |
| 132 | + @reinit__is_reduced |
| 133 | + def reset(self) -> None: |
| 134 | + self._sum_ndcg_per_k = torch.zeros(len(self.top_k), device=self._device) |
| 135 | + self._num_examples = 0 |
| 136 | + |
| 137 | + def _compute_dcg(self, relevance_scores: torch.Tensor, k: int) -> torch.Tensor: |
| 138 | + """Compute DCG@k for a batch of relevance scores. |
| 139 | + |
| 140 | + Args: |
| 141 | + relevance_scores: Tensor of shape (batch, num_items) with relevance scores at ranked positions |
| 142 | + k: Number of positions to consider |
| 143 | + |
| 144 | + Returns: |
| 145 | + DCG scores of shape (batch,) |
| 146 | + """ |
| 147 | + # Handle case where k > actual number of items |
| 148 | + actual_k = min(k, relevance_scores.shape[1]) |
| 149 | + |
| 150 | + # Create position weights: 1/log2(position + 1) for position in [1, actual_k] |
| 151 | + # Positions are 1-indexed in the DCG formula |
| 152 | + positions = torch.arange(1, actual_k + 1, dtype=torch.float32, device=relevance_scores.device) |
| 153 | + discounts = 1.0 / torch.log2(positions + 1) # log2(i+1) for i in [1, actual_k] |
| 154 | + |
| 155 | + # Compute gains: 2^rel - 1 |
| 156 | + gains = torch.pow(2.0, relevance_scores[:, :actual_k]) - 1.0 |
| 157 | + |
| 158 | + # DCG = sum of (gain / discount) |
| 159 | + dcg = (gains * discounts).sum(dim=-1) |
| 160 | + return dcg |
| 161 | + |
| 162 | + @reinit__is_reduced |
| 163 | + def update(self, output: tuple[torch.Tensor, torch.Tensor]) -> None: |
| 164 | + if len(output) != 2: |
| 165 | + raise ValueError(f"output should be in format `(y_pred,y)` but got tuple of {len(output)} tensors.") |
| 166 | + |
| 167 | + y_pred, y = output |
| 168 | + if y_pred.shape != y.shape: |
| 169 | + raise ValueError(f"y_pred and y must be in the same shape, got {y_pred.shape} != {y.shape}.") |
| 170 | + |
| 171 | + # Filter out examples with no relevant items if ignore_zero_hits is True |
| 172 | + if self.ignore_zero_hits: |
| 173 | + valid_mask = torch.any(y >= self.relevance_threshold, dim=-1) |
| 174 | + y_pred = y_pred[valid_mask] |
| 175 | + y = y[valid_mask] |
| 176 | + |
| 177 | + if y.shape[0] == 0: |
| 178 | + return |
| 179 | + |
| 180 | + # Zero out items below relevance threshold for DCG computation |
| 181 | + y_for_dcg = torch.where(y >= self.relevance_threshold, y, torch.zeros_like(y)) |
| 182 | + |
| 183 | + max_k = self.top_k[-1] |
| 184 | + |
| 185 | + # Get ranked indices based on predictions (stable=True for deterministic tie-breaking) |
| 186 | + ranked_indices = torch.argsort(y_pred, dim=-1, descending=True, stable=True)[:, :max_k] |
| 187 | + |
| 188 | + # Get relevance scores in the predicted ranking order |
| 189 | + ranked_relevance = torch.gather(y_for_dcg, dim=-1, index=ranked_indices) |
| 190 | + |
| 191 | + # Compute ideal ranking by sorting true relevance scores |
| 192 | + ideal_relevance = torch.sort(y_for_dcg, dim=-1, descending=True, stable=True)[0][:, :max_k] |
| 193 | + |
| 194 | + for i, k in enumerate(self.top_k): |
| 195 | + # Compute DCG@k and IDCG@k |
| 196 | + dcg_k = self._compute_dcg(ranked_relevance, k) |
| 197 | + idcg_k = self._compute_dcg(ideal_relevance, k) |
| 198 | + |
| 199 | + # NDCG = DCG / IDCG, handle division by zero (when IDCG = 0, NDCG = 0) |
| 200 | + ndcg_k = torch.where( |
| 201 | + idcg_k > 0, |
| 202 | + dcg_k / idcg_k, |
| 203 | + torch.zeros_like(dcg_k) |
| 204 | + ) |
| 205 | + |
| 206 | + self._sum_ndcg_per_k[i] += ndcg_k.sum().to(self._device) |
| 207 | + |
| 208 | + self._num_examples += y.shape[0] |
| 209 | + |
| 210 | + @sync_all_reduce("_sum_ndcg_per_k", "_num_examples") |
| 211 | + def compute(self) -> list[float]: |
| 212 | + if self._num_examples == 0: |
| 213 | + raise NotComputableError("NDCG must have at least one example.") |
| 214 | + |
| 215 | + ndcg_scores = (self._sum_ndcg_per_k / self._num_examples).tolist() |
| 216 | + return ndcg_scores |
0 commit comments