Skip to content

Commit 3716c6a

Browse files
committed
implement registry based topk wrapper
1 parent 82bb998 commit 3716c6a

File tree

2 files changed

+88
-2
lines changed

2 files changed

+88
-2
lines changed

ignite/metrics/precision.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,10 @@ def _prepare_output(self, output: Sequence[torch.Tensor]) -> Sequence[torch.Tens
8888

8989
return y_pred, y, correct
9090

91+
@classmethod
92+
def _topk_transform(cls, output: Sequence[torch.tensor], top_k: int) -> Sequence[torch.tensor]:
93+
return output
94+
9195
@reinit__is_reduced
9296
def reset(self) -> None:
9397
"""
@@ -410,8 +414,9 @@ def update(self, output: Sequence[torch.Tensor]) -> None:
410414
For binary and multilabel data, both y and y_pred should consist of 0's and 1's, but for multiclass
411415
data, y_pred and y should consist of probabilities and integers respectively.
412416
"""
413-
self._check_shape(output)
414-
self._check_type(output)
417+
if not getattr(self, "_skip_checks", False):
418+
self._check_shape(output)
419+
self._check_type(output)
415420
y_pred, y, correct = self._prepare_output(output)
416421

417422
if self._average == "samples":

ignite/metrics/top_k.py

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
import torch
2+
from typing import Sequence
3+
4+
from ignite.metrics import Metric
5+
from ignite.metrics.precision import _BasePrecisionRecall
6+
7+
8+
class TopK(Metric):
    """Wrap a base metric so it is computed at one or several ``top_k`` cutoffs.

    Registry design inspired by
    https://github.com/open-mmlab/mmengine/blob/main/mmengine/registry/registry.py
    and https://github.com/open-mmlab/mmdetection/tree/main

    The idea is to maintain top-k transforms here, keyed by metric class.
    The user passes ``output_transform`` to ``TopK`` instead of to the base
    metric, and that transform must only unpack the output into
    ``(y_pred, y)`` — no other transformation such as binarisation, which is
    applied per-k by the registered transform.
    """

    # Class-level (shared) registry mapping a Metric subclass to a callable
    # ``(output, k) -> output`` that restricts y_pred to its top-k entries.
    _output_transform_registry: dict = {}

    @classmethod
    def register(cls, metric_type, k_transform) -> None:
        """Register ``k_transform`` as the top-k transform for ``metric_type``.

        When several registered types match a metric, the last matching
        registry entry wins (see ``__init__``).
        """
        cls._output_transform_registry[metric_type] = k_transform

    def __init__(
        self,
        base_metric: Metric,
        top_k: int | list[int],
        output_transform=lambda x: x,
        device: str | torch.device = torch.device("cpu"),
        skip_unrolling: bool = False,
    ):
        # Resolve the per-k transform for this metric type; iteration keeps
        # the LAST matching entry, so later registrations take precedence.
        transform = None
        for metric_type, k_transform in self._output_transform_registry.items():
            if isinstance(base_metric, metric_type):
                transform = k_transform

        if transform is None:
            raise ValueError(f"{type(base_metric).__name__} does not support TopK.")

        ks = top_k if isinstance(top_k, list) else [top_k]
        # An empty list would silently produce a metric that never updates.
        if not ks:
            raise ValueError("top_k must be an int or a non-empty list of ints.")
        if any(not isinstance(k, int) or k < 1 for k in ks):
            raise ValueError(f"All top_k values must be positive integers, got {top_k}.")

        self._transform = transform
        self._base_metric = base_metric
        self._ks = sorted(ks)
        super().__init__(output_transform=output_transform, device=device, skip_unrolling=skip_unrolling)

    def reset(self) -> None:
        """Reset the base metric and snapshot one fresh state per k."""
        self._base_metric.reset()
        self._states = {k: self._base_metric.state_dict() for k in self._ks}

    def update(self, output) -> None:
        """Validate ``output`` once, then update the base metric once per k.

        Shape/type checks run on the raw output; the per-k updates skip them
        (via the base metric's ``_skip_checks`` flag) because the binarised
        top-k output would not pass the original type check.
        """
        self._base_metric._check_shape(output)
        self._base_metric._check_type(output)
        self._base_metric._skip_checks = True
        try:
            for k in self._ks:
                # Restore the accumulated state for this k before updating.
                self._base_metric.load_state_dict(self._states[k])

                k_output = self._transform(output, k)
                self._base_metric.update(k_output)

                # Save the updated state for this k.
                self._states[k] = self._base_metric.state_dict()
        finally:
            # Always clear the flag — without this, a failing update would
            # leave the base metric permanently skipping validation.
            self._base_metric._skip_checks = False

    def compute(self) -> list:
        """Return the base metric's result for each k, in ascending k order."""
        results = []
        for k in self._ks:
            self._base_metric.load_state_dict(self._states[k])
            results.append(self._base_metric.compute())
        return results
71+
72+
73+
def _precision_recall_topk_transform(output: Sequence[torch.Tensor], k: int):
74+
y_pred, y = output[0], output[1]
75+
_, top_indices = torch.topk(y_pred, k=k, dim=-1)
76+
masked = torch.zeros_like(y_pred)
77+
masked.scatter_(-1, top_indices, 1.0)
78+
return (masked, y)
79+
80+
81+
# Register the precision/recall top-k transform so TopK can wrap any
# _BasePrecisionRecall subclass (e.g. Precision, Recall).
TopK.register(_BasePrecisionRecall, _precision_recall_topk_transform)

0 commit comments

Comments
 (0)