Skip to content

Commit a6edab5

Browse files
committed
Add NDCG metric to rec_sys
1 parent e56f4f9 commit a6edab5

File tree

2 files changed

+448
-0
lines changed

2 files changed

+448
-0
lines changed

ignite/metrics/rec_sys/ndcg.py

Lines changed: 216 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,216 @@
1+
from typing import Callable
2+
3+
import torch
4+
5+
from ignite.exceptions import NotComputableError
6+
from ignite.metrics.metric import Metric, reinit__is_reduced, sync_all_reduce
7+
8+
__all__ = ["NDCG"]
9+
10+
11+
class NDCG(Metric):
    r"""Calculates the Normalized Discounted Cumulative Gain (NDCG) at `k` for Recommendation Systems.

    NDCG measures the quality of ranking by considering both the relevance of items and their
    positions in the ranked list. It compares the achieved DCG against the ideal DCG (IDCG)
    obtained by sorting items by their true relevance.

    .. math::
       \text{NDCG}@K = \frac{\text{DCG}@K}{\text{IDCG}@K}

    where:

    .. math::
       \text{DCG}@K = \sum_{i=1}^{K} \frac{2^{\text{rel}_i} - 1}{\log_2(i + 1)}

    and :math:`\text{rel}_i` is the relevance score of the item at position :math:`i` in the
    ranked list (1-indexed).

    - ``update`` must receive output of the form ``(y_pred, y)``.
    - ``y_pred`` is expected to be raw logits or probability score for each item in the catalog.
    - ``y`` is expected to contain relevance scores (can be binary or graded).
    - ``y_pred`` and ``y`` are only allowed shape :math:`(batch, num\_items)`.
    - returns a list of NDCG ordered by the sorted values of ``top_k``.

    Args:
        top_k: a non-empty list of positive integers that specifies `k` for calculating NDCG@top-k.
            The list is sorted internally; results are reported in ascending order of `k`.
        ignore_zero_hits: if True, users with no relevant items (ground truth tensor being all zeros)
            are ignored in computation of NDCG. If set False, such users are counted with NDCG of 0.
            By default, True.
        relevance_threshold: minimum label value to be considered relevant. Defaults to ``1``,
            which handles standard binary labels and graded relevance scales (e.g. TREC-style
            0-4) by treating any label >= 1 as relevant. Items below this threshold contribute
            0 to DCG/IDCG calculations.
        output_transform: a callable that is used to transform the
            :class:`~ignite.engine.engine.Engine`'s ``process_function``'s output into the
            form expected by the metric.
            The output is expected to be a tuple `(prediction, target)`
            where `prediction` and `target` are tensors
            of shape ``(batch, num_items)``.
        device: specifies which device updates are accumulated on. Setting the
            metric's device to be the same as your ``update`` arguments ensures the ``update`` method is
            non-blocking. By default, CPU.
        skip_unrolling: specifies whether input should be unrolled or not before being
            processed. Should be true for multi-output models.

    Examples:
        To use with ``Engine`` and ``process_function``, simply attach the metric instance to the engine.
        The output of the engine's ``process_function`` needs to be in the format of
        ``(y_pred, y)``. If not, ``output_transform`` can be added
        to the metric to transform the output into the form expected by the metric.

        For more information on how metric works with :class:`~ignite.engine.engine.Engine`, visit :ref:`attach-engine`.

        .. include:: defaults.rst
            :start-after: :orphan:

        ignore_zero_hits=True case

        .. testcode:: 1

            metric = NDCG(top_k=[1, 2, 3, 4])
            metric.attach(default_evaluator, "ndcg")
            y_pred = torch.Tensor([
                [4.0, 2.0, 3.0, 1.0],
                [1.0, 2.0, 3.0, 4.0]
            ])
            y_true = torch.Tensor([
                [0.0, 0.0, 1.0, 1.0],
                [0.0, 0.0, 0.0, 0.0]
            ])
            state = default_evaluator.run([(y_pred, y_true)])
            print(state.metrics["ndcg"])

        .. testoutput:: 1

            [0.0, 0.63..., 0.63..., 0.63...]

        ignore_zero_hits=False case

        .. testcode:: 2

            metric = NDCG(top_k=[1, 2, 3, 4], ignore_zero_hits=False)
            metric.attach(default_evaluator, "ndcg")
            y_pred = torch.Tensor([
                [4.0, 2.0, 3.0, 1.0],
                [1.0, 2.0, 3.0, 4.0]
            ])
            y_true = torch.Tensor([
                [0.0, 0.0, 1.0, 1.0],
                [0.0, 0.0, 0.0, 0.0]
            ])
            state = default_evaluator.run([(y_pred, y_true)])
            print(state.metrics["ndcg"])

        .. testoutput:: 2

            [0.0, 0.31..., 0.31..., 0.31...]

    .. versionadded:: 0.6.0
    """

    required_output_keys = ("y_pred", "y")
    _state_dict_all_req_keys = ("_sum_ndcg_per_k", "_num_examples")

    def __init__(
        self,
        top_k: list[int],
        ignore_zero_hits: bool = True,
        relevance_threshold: float = 1.0,
        output_transform: Callable = lambda x: x,
        device: str | torch.device = torch.device("cpu"),
        skip_unrolling: bool = False,
    ) -> None:
        # Reject an empty list early: otherwise the metric would silently
        # accumulate a zero-length state tensor and ``compute`` would return [].
        if not top_k:
            raise ValueError("top_k must be a non-empty list of positive integers.")
        if any(k <= 0 for k in top_k):
            raise ValueError("top_k must be list of positive integers only.")

        # Store in ascending order so results align with sorted(top_k) and
        # ``top_k[-1]`` is the largest cutoff needed in ``update``.
        self.top_k = sorted(top_k)
        self.ignore_zero_hits = ignore_zero_hits
        self.relevance_threshold = relevance_threshold
        super(NDCG, self).__init__(output_transform, device=device, skip_unrolling=skip_unrolling)

    @reinit__is_reduced
    def reset(self) -> None:
        # One running NDCG sum per requested k, plus the example count used
        # for the final mean in ``compute``.
        self._sum_ndcg_per_k = torch.zeros(len(self.top_k), device=self._device)
        self._num_examples = 0

    def _compute_dcg(self, relevance_scores: torch.Tensor, k: int) -> torch.Tensor:
        """Compute DCG@k for a batch of relevance scores.

        Args:
            relevance_scores: Tensor of shape (batch, num_items) with relevance scores at ranked positions
            k: Number of positions to consider

        Returns:
            DCG scores of shape (batch,)
        """
        # Handle case where k > actual number of items
        actual_k = min(k, relevance_scores.shape[1])

        # Position discounts: 1 / log2(i + 1) for 1-indexed position i in [1, actual_k]
        positions = torch.arange(1, actual_k + 1, dtype=torch.float32, device=relevance_scores.device)
        discounts = 1.0 / torch.log2(positions + 1)

        # Exponential gain: 2^rel - 1 (items with rel = 0 contribute nothing)
        gains = torch.pow(2.0, relevance_scores[:, :actual_k]) - 1.0

        # DCG = sum over positions of gain * discount
        return (gains * discounts).sum(dim=-1)

    @reinit__is_reduced
    def update(self, output: tuple[torch.Tensor, torch.Tensor]) -> None:
        """Accumulate per-k NDCG sums for one batch.

        Args:
            output: tuple ``(y_pred, y)`` of tensors, each of shape ``(batch, num_items)``.

        Raises:
            ValueError: if ``output`` is not a 2-tuple or the tensor shapes differ.
        """
        if len(output) != 2:
            raise ValueError(f"output should be in format `(y_pred,y)` but got tuple of {len(output)} tensors.")

        y_pred, y = output
        if y_pred.shape != y.shape:
            raise ValueError(f"y_pred and y must be in the same shape, got {y_pred.shape} != {y.shape}.")

        # Filter out examples with no relevant items if ignore_zero_hits is True
        if self.ignore_zero_hits:
            valid_mask = torch.any(y >= self.relevance_threshold, dim=-1)
            y_pred = y_pred[valid_mask]
            y = y[valid_mask]

        # Nothing left to score in this batch
        if y.shape[0] == 0:
            return

        # Zero out items below relevance threshold so they add no gain to DCG/IDCG
        y_for_dcg = torch.where(y >= self.relevance_threshold, y, torch.zeros_like(y))

        # Only the largest cutoff worth of ranked positions is ever needed
        max_k = self.top_k[-1]

        # Rank items by prediction score (stable=True for deterministic tie-breaking)
        ranked_indices = torch.argsort(y_pred, dim=-1, descending=True, stable=True)[:, :max_k]

        # Relevance scores in the predicted ranking order
        ranked_relevance = torch.gather(y_for_dcg, dim=-1, index=ranked_indices)

        # Ideal ranking: items sorted by their true relevance
        ideal_relevance = torch.sort(y_for_dcg, dim=-1, descending=True, stable=True)[0][:, :max_k]

        for i, k in enumerate(self.top_k):
            dcg_k = self._compute_dcg(ranked_relevance, k)
            idcg_k = self._compute_dcg(ideal_relevance, k)

            # NDCG = DCG / IDCG; defined as 0 when IDCG = 0 (no relevant items)
            ndcg_k = torch.where(idcg_k > 0, dcg_k / idcg_k, torch.zeros_like(dcg_k))

            self._sum_ndcg_per_k[i] += ndcg_k.sum().to(self._device)

        self._num_examples += y.shape[0]

    @sync_all_reduce("_sum_ndcg_per_k", "_num_examples")
    def compute(self) -> list[float]:
        """Return mean NDCG per `k`, ordered by ascending `k`.

        Raises:
            NotComputableError: if no examples have been accumulated via ``update``.
        """
        if self._num_examples == 0:
            raise NotComputableError("NDCG must have at least one example.")

        return (self._sum_ndcg_per_k / self._num_examples).tolist()

0 commit comments

Comments
 (0)