Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 72 additions & 0 deletions ignite/metrics/accuracy.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,12 +211,84 @@ def thresholded_output_transform(output):

0.6666...

Sequence case. For sequence-to-sequence models (e.g., NLP transformers), use
:meth:`get_sequence_transform` to handle flattening and padding mask:

.. testcode:: 5

metric = Accuracy(output_transform=Accuracy.get_sequence_transform(pad_index=0))
metric.attach(default_evaluator, "accuracy")
# (batch=2, seq_len=4, num_classes=3)
y_pred = torch.tensor([
[[0.1, 0.8, 0.1], [0.9, 0.05, 0.05], [0.2, 0.1, 0.7], [0.0, 0.0, 0.0]],
[[0.7, 0.2, 0.1], [0.1, 0.1, 0.8], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]],
])
# (batch=2, seq_len=4), 0 = padding
y_true = torch.tensor([
[1, 0, 2, 0],
[0, 2, 0, 0],
])
state = default_evaluator.run([[y_pred, y_true]])
print(state.metrics["accuracy"])

.. testoutput:: 5

1.0

.. versionchanged:: 0.5.1
``skip_unrolling`` argument is added.
"""

_state_dict_all_req_keys = ("_num_correct", "_num_examples")

@staticmethod
def get_sequence_transform(
pad_index: int = 0,
output_transform: Callable = lambda x: x,
) -> Callable:
"""Returns an output transform for sequence-to-sequence models.

Handles the common NLP/Transformer pattern where model output is
``(batch, seq_len, num_classes)`` and target is ``(batch, seq_len)``.
The returned transform flattens both tensors along the sequence
dimension and masks out positions where the target equals ``pad_index``.

Args:
pad_index: the target value used for padding. Positions where
``y == pad_index`` are excluded from the accuracy computation.
Default: 0.
output_transform: an optional transform applied to the engine
output before the sequence transform. Default is the identity
function.

Returns:
A callable suitable for the ``output_transform`` argument of
:class:`Accuracy`.

Example::

from ignite.metrics import Accuracy

acc = Accuracy(
output_transform=Accuracy.get_sequence_transform(pad_index=0)
)

.. versionadded:: 0.5.3
"""

def _sequence_transform(output: tuple) -> tuple:
y_pred, y = output_transform(output)

# y_pred: (batch, seq_len, num_classes) -> (batch * seq_len, num_classes)
# y: (batch, seq_len) -> (batch * seq_len,)
y_pred_flat = y_pred.reshape(-1, y_pred.size(-1))
y_flat = y.reshape(-1)

mask = y_flat != pad_index
return y_pred_flat[mask], y_flat[mask]

return _sequence_transform

def __init__(
self,
output_transform: Callable = lambda x: x,
Expand Down
89 changes: 89 additions & 0 deletions tests/ignite/metrics/test_accuracy.py
Original file line number Diff line number Diff line change
Expand Up @@ -498,3 +498,92 @@ def update(self, output):
state = State(output=(y_pred, y_true))
engine = MagicMock(state=state)
acc.iteration_completed(engine)


class TestGetSequenceTransform:
    """Tests for :meth:`Accuracy.get_sequence_transform`."""

    def test_basic_masking(self):
        """Positions where the target equals ``pad_index`` are dropped."""
        transform = Accuracy.get_sequence_transform(pad_index=0)

        # (batch=2, seq_len=3, num_classes=4)
        y_pred = torch.tensor([
            [[0.1, 0.8, 0.05, 0.05], [0.9, 0.05, 0.025, 0.025], [0.0, 0.0, 0.0, 0.0]],
            [[0.7, 0.1, 0.1, 0.1], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]],
        ])
        # (batch=2, seq_len=3); with pad_index=0 every y == 0 position is
        # masked, so only position (0, 0) with target 1 survives.
        y = torch.tensor([
            [1, 0, 0],
            [0, 0, 0],
        ])

        y_pred_out, y_out = transform((y_pred, y))

        assert y_out.shape[0] == 1
        assert y_out[0] == 1
        # Predictions keep the class dimension for the surviving position.
        assert y_pred_out.shape == (1, 4)

    def test_all_tokens_unmasked(self):
        """With a pad_index absent from the targets, nothing is masked."""
        transform = Accuracy.get_sequence_transform(pad_index=-1)

        y_pred = torch.rand(2, 4, 3)
        y = torch.tensor([[0, 1, 2, 1], [2, 0, 1, 0]])

        y_pred_out, y_out = transform((y_pred, y))
        assert y_out.shape[0] == 8  # 2 * 4, nothing masked

    def test_accuracy_with_sequence_transform(self):
        """End-to-end: Accuracy metric fed through the sequence transform."""
        transform = Accuracy.get_sequence_transform(pad_index=0)
        acc = Accuracy()

        # (batch=1, seq_len=3, num_classes=3)
        y_pred = torch.tensor([[[0.1, 0.8, 0.1], [0.1, 0.1, 0.8], [0.0, 0.0, 0.0]]])
        # (batch=1, seq_len=3); the last position (target 0) is padding
        y = torch.tensor([[1, 2, 0]])

        acc.update(transform((y_pred, y)))
        assert acc.compute() == 1.0  # both non-pad predictions are correct

    def test_accuracy_with_wrong_predictions(self):
        """Accuracy reflects incorrect predictions on unmasked positions."""
        transform = Accuracy.get_sequence_transform(pad_index=-1)
        acc = Accuracy()

        # (batch=1, seq_len=4, num_classes=3)
        y_pred = torch.tensor([
            [[0.8, 0.1, 0.1], [0.1, 0.8, 0.1], [0.1, 0.1, 0.8], [0.8, 0.1, 0.1]],
        ])
        # (batch=1, seq_len=4); third position is wrong (pred=2, true=0)
        y = torch.tensor([[0, 1, 0, 0]])

        acc.update(transform((y_pred, y)))
        assert acc.compute() == pytest.approx(0.75)  # 3 out of 4 correct

    def test_custom_pad_index(self):
        """A non-zero pad index masks exactly the matching target values."""
        transform = Accuracy.get_sequence_transform(pad_index=99)

        y_pred = torch.rand(2, 5, 10)
        y = torch.tensor([[1, 2, 3, 99, 99], [4, 5, 99, 99, 99]])

        y_pred_out, y_out = transform((y_pred, y))
        assert y_out.shape[0] == 5  # 3 + 2 non-pad tokens

    def test_with_output_transform(self):
        """The sequence transform composes with a custom output_transform."""

        def extract_from_dict(output):
            return output["logits"], output["labels"]

        transform = Accuracy.get_sequence_transform(
            pad_index=0, output_transform=extract_from_dict
        )

        y_pred = torch.rand(2, 3, 5)
        y = torch.tensor([[1, 2, 0], [3, 0, 0]])
        output = {"logits": y_pred, "labels": y}

        y_pred_out, y_out = transform(output)
        assert y_out.shape[0] == 3  # 2 + 1 non-pad tokens