From 4be793e72c1522a9dd4c5989fa6886a07d5e4f1a Mon Sep 17 00:00:00 2001 From: Joe Munene Date: Fri, 10 Apr 2026 03:02:05 +0300 Subject: [PATCH] feat(metrics): add get_sequence_transform to Accuracy Add a static method `Accuracy.get_sequence_transform(pad_index)` that returns an output_transform for sequence-to-sequence models. Handles flattening (batch, seq_len, num_classes) tensors and masking padding tokens, eliminating the need for custom boilerplate in NLP/Transformer training loops. Closes #3593 --- ignite/metrics/accuracy.py | 72 ++++++++++++++++++++++ tests/ignite/metrics/test_accuracy.py | 89 +++++++++++++++++++++++++++ 2 files changed, 161 insertions(+) diff --git a/ignite/metrics/accuracy.py b/ignite/metrics/accuracy.py index 984218559a3c..30fdd411a07c 100644 --- a/ignite/metrics/accuracy.py +++ b/ignite/metrics/accuracy.py @@ -211,12 +211,84 @@ def thresholded_output_transform(output): 0.6666... + Sequence case. For sequence-to-sequence models (e.g., NLP transformers), use + :meth:`get_sequence_transform` to handle flattening and padding mask: + + .. testcode:: 5 + + metric = Accuracy(output_transform=Accuracy.get_sequence_transform(pad_index=0)) + metric.attach(default_evaluator, "accuracy") + # (batch=2, seq_len=4, num_classes=3) + y_pred = torch.tensor([ + [[0.1, 0.8, 0.1], [0.9, 0.05, 0.05], [0.2, 0.1, 0.7], [0.0, 0.0, 0.0]], + [[0.7, 0.2, 0.1], [0.1, 0.1, 0.8], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]], + ]) + # (batch=2, seq_len=4), 0 = padding + y_true = torch.tensor([ + [1, 0, 2, 0], + [0, 2, 0, 0], + ]) + state = default_evaluator.run([[y_pred, y_true]]) + print(state.metrics["accuracy"]) + + .. testoutput:: 5 + + 1.0 + .. versionchanged:: 0.5.1 ``skip_unrolling`` argument is added. """ _state_dict_all_req_keys = ("_num_correct", "_num_examples") + @staticmethod + def get_sequence_transform( + pad_index: int = 0, + output_transform: Callable = lambda x: x, + ) -> Callable: + """Returns an output transform for sequence-to-sequence models. 
+ + Handles the common NLP/Transformer pattern where model output is + ``(batch, seq_len, num_classes)`` and target is ``(batch, seq_len)``. + The returned transform flattens both tensors along the sequence + dimension and masks out positions where the target equals ``pad_index``. + + Args: + pad_index: the target value used for padding. Positions where + ``y == pad_index`` are excluded from the accuracy computation. + Default: 0. + output_transform: an optional transform applied to the engine + output before the sequence transform. Default is the identity + function. + + Returns: + A callable suitable for the ``output_transform`` argument of + :class:`Accuracy`. + + Example:: + + from ignite.metrics import Accuracy + + acc = Accuracy( + output_transform=Accuracy.get_sequence_transform(pad_index=0) + ) + + .. versionadded:: 0.5.3 + """ + + def _sequence_transform(output: tuple) -> tuple: + y_pred, y = output_transform(output) + + # y_pred: (batch, seq_len, num_classes) -> (batch * seq_len, num_classes) + # y: (batch, seq_len) -> (batch * seq_len,) + y_pred_flat = y_pred.reshape(-1, y_pred.size(-1)) + y_flat = y.reshape(-1) + + mask = y_flat != pad_index + return y_pred_flat[mask], y_flat[mask] + + return _sequence_transform + def __init__( self, output_transform: Callable = lambda x: x, diff --git a/tests/ignite/metrics/test_accuracy.py b/tests/ignite/metrics/test_accuracy.py index 668a40b27c8a..8ea3bd053b82 100644 --- a/tests/ignite/metrics/test_accuracy.py +++ b/tests/ignite/metrics/test_accuracy.py @@ -498,3 +498,92 @@ def update(self, output): state = State(output=(y_pred, y_true)) engine = MagicMock(state=state) acc.iteration_completed(engine) + + +class TestGetSequenceTransform: + def test_basic_masking(self): + """Test that padding tokens are correctly masked out.""" + transform = Accuracy.get_sequence_transform(pad_index=0) + + # (batch=2, seq_len=3, num_classes=4) + y_pred = torch.tensor([ + [[0.1, 0.8, 0.05, 0.05], [0.9, 0.05, 0.025, 0.025], [0.0, 0.0, 
0.0, 0.0]],
+ [[0.7, 0.1, 0.1, 0.1], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]],
+ ])
+ # (batch=2, seq_len=3), 0 = padding
+ y = torch.tensor([
+ [1, 0, 0],
+ [0, 0, 0],
+ ])
+
+ y_pred_out, y_out = transform((y_pred, y))
+
+ # pad_index=0 masks every position where y == 0, so all the
+ # padding targets above are dropped. Only position (0, 0),
+ # where y == 1, survives the mask.
+ assert y_out.shape[0] == 1
+ assert y_out[0] == 1
+
+ def test_all_tokens_unmasked(self):
+ """Test with no padding (pad_index not present in targets)."""
+ transform = Accuracy.get_sequence_transform(pad_index=-1)
+
+ y_pred = torch.rand(2, 4, 3)
+ y = torch.tensor([[0, 1, 2, 1], [2, 0, 1, 0]])
+
+ y_pred_out, y_out = transform((y_pred, y))
+ assert y_out.shape[0] == 8 # 2 * 4, nothing masked
+
+ def test_accuracy_with_sequence_transform(self):
+ """End-to-end test: Accuracy metric with sequence transform."""
+ acc = Accuracy(output_transform=Accuracy.get_sequence_transform(pad_index=0))
+
+ # (batch=1, seq_len=3, num_classes=3)
+ y_pred = torch.tensor([[[0.1, 0.8, 0.1], [0.1, 0.1, 0.8], [0.0, 0.0, 0.0]]])
+ # (batch=1, seq_len=3); the third token (value 0) is padding
+ y = torch.tensor([[1, 2, 0]])
+
+ acc.update(acc._output_transform((y_pred, y)))
+ result = acc.compute()
+ assert result == 1.0 # Both non-pad predictions are correct
+
+ def test_accuracy_with_wrong_predictions(self):
+ """Test accuracy with some wrong predictions."""
+ acc = Accuracy(output_transform=Accuracy.get_sequence_transform(pad_index=-1))
+
+ # (batch=1, seq_len=4, num_classes=3)
+ y_pred = torch.tensor([
+ [[0.8, 0.1, 0.1], [0.1, 0.8, 0.1], [0.1, 0.1, 0.8], [0.8, 0.1, 0.1]],
+ ])
+ # (batch=1, seq_len=4)
+ y = torch.tensor([[0, 1, 0, 0]]) # 3rd token is wrong (pred=2, true=0)
+
+ acc.update(acc._output_transform((y_pred, y)))
+ result = acc.compute()
+ assert result == pytest.approx(0.75) # 3 out of 4 correct
+
+ def test_custom_pad_index(self):
+ """Test with a non-zero pad index."""
+ transform = 
Accuracy.get_sequence_transform(pad_index=99) + + y_pred = torch.rand(2, 5, 10) + y = torch.tensor([[1, 2, 3, 99, 99], [4, 5, 99, 99, 99]]) + + y_pred_out, y_out = transform((y_pred, y)) + assert y_out.shape[0] == 5 # 3 + 2 non-pad tokens + + def test_with_output_transform(self): + """Test composing with a custom output_transform.""" + def extract_from_dict(output): + return output["logits"], output["labels"] + + transform = Accuracy.get_sequence_transform( + pad_index=0, output_transform=extract_from_dict + ) + + y_pred = torch.rand(2, 3, 5) + y = torch.tensor([[1, 2, 0], [3, 0, 0]]) + output = {"logits": y_pred, "labels": y} + + y_pred_out, y_out = transform(output) + assert y_out.shape[0] == 3 # 2 + 1 non-pad tokens