
Add get_sequence_transform to Accuracy for seq2seq models#3728

Open
joemunene-by wants to merge 1 commit into pytorch:master from joemunene-by:feat/sequence-accuracy-transform

Conversation

@joemunene-by

Summary

Adds Accuracy.get_sequence_transform(pad_index, output_transform) — a static method that returns a pre-built output transform for sequence-to-sequence models, addressing the common NLP/Transformer boilerplate discussed in #3593.

Before (manual boilerplate in every project):

def accuracy_transform(output):
    preds, y = output
    # (batch, seq_len, num_classes) -> (batch * seq_len, num_classes)
    preds_flat = preds.reshape(-1, preds.size(-1))
    y_flat = y.reshape(-1)
    mask = y_flat != pad_idx  # pad_idx defined elsewhere in the script
    return preds_flat[mask], y_flat[mask]

acc = Accuracy(output_transform=accuracy_transform)

After (one-liner):

acc = Accuracy(output_transform=Accuracy.get_sequence_transform(pad_index=0))

What it does

  • Flattens (batch, seq_len, num_classes) predictions → (batch * seq_len, num_classes)
  • Flattens (batch, seq_len) targets → (batch * seq_len,)
  • Masks out positions where target == pad_index
  • Supports optional output_transform composition (e.g., extracting from dicts)
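The steps above could be sketched roughly like this (a hypothetical implementation for illustration only; the defaults and the inner function name are assumptions, not the PR's actual code):

```python
import torch

def get_sequence_transform(pad_index=0, output_transform=None):
    """Hypothetical sketch of the proposed static method's behavior."""
    def transform(output):
        # Optional composition step, e.g. extracting tensors from a dict
        if output_transform is not None:
            output = output_transform(output)
        preds, y = output
        # (batch, seq_len, num_classes) -> (batch * seq_len, num_classes)
        preds_flat = preds.reshape(-1, preds.size(-1))
        # (batch, seq_len) -> (batch * seq_len,)
        y_flat = y.reshape(-1)
        # Drop positions where the target is the padding token
        mask = y_flat != pad_index
        return preds_flat[mask], y_flat[mask]
    return transform
```

Passing the returned callable as Accuracy's output_transform reproduces the manual boilerplate shown above in a single line.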

Design

Follows @vfdev-5's suggestion in #3593 to use a static method returning a transform rather than creating a new metric class. This keeps the API clean and avoids duplication.

Tests

6 new tests covering:

  • Basic padding mask behavior
  • No padding (all tokens unmasked)
  • End-to-end with Accuracy metric (correct predictions)
  • Partial accuracy with wrong predictions
  • Custom pad index
  • Composition with custom output_transform
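As a self-contained illustration of the masking behavior these tests exercise (hand-built tensors for illustration, not the PR's actual test code), the expected accuracy with padding can be checked by hand:

```python
import torch

# 1 batch, 4 positions, 3 classes; positions 2-3 are padding (target == 0)
preds = torch.tensor([[[0.1, 0.9, 0.0],    # argmax 1, target 1 -> correct
                       [0.8, 0.1, 0.1],    # argmax 0, target 2 -> wrong
                       [0.9, 0.0, 0.1],    # padded, ignored
                       [0.9, 0.0, 0.1]]])  # padded, ignored
y = torch.tensor([[1, 2, 0, 0]])

preds_flat = preds.reshape(-1, preds.size(-1))
y_flat = y.reshape(-1)
mask = y_flat != 0  # pad_index == 0
correct = (preds_flat[mask].argmax(-1) == y_flat[mask]).float()
accuracy = correct.mean().item()  # 1 correct of 2 unpadded positions -> 0.5
```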

All existing tests pass (140 passed, 292 skipped for CUDA/MPS).

Closes #3593

Add a static method `Accuracy.get_sequence_transform(pad_index)` that
returns an output_transform for sequence-to-sequence models. Handles
flattening (batch, seq_len, num_classes) tensors and masking padding
tokens, eliminating the need for custom boilerplate in NLP/Transformer
training loops.

Closes pytorch#3593
@github-actions bot added the module: metrics (Metrics module) label on Apr 10, 2026


Development

Successfully merging this pull request may close these issues.

[Feature]: SequenceAccuracy Metric for PyTorch Ignite
