Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions ignite/metrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from ignite.metrics.accumulation import Average, GeometricAverage, VariableAccumulation
from ignite.metrics.accuracy import Accuracy
from ignite.metrics.utils import get_sequence_transform
from ignite.metrics.average_precision import AveragePrecision
from ignite.metrics.classification_report import ClassificationReport
from ignite.metrics.cohen_kappa import CohenKappa
Expand Down Expand Up @@ -54,6 +55,7 @@
"Loss",
"MetricGroup",
"MetricsLambda",
"get_sequence_transform",
"MeanAbsoluteError",
"MeanPairwiseDistance",
"MeanSquaredError",
Expand Down
2 changes: 2 additions & 0 deletions ignite/metrics/accuracy.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,3 +263,5 @@ def compute(self) -> float:
if self._num_examples == 0:
raise NotComputableError("Accuracy must have at least one example before it can be computed.")
return self._num_correct.item() / self._num_examples

Comment thread
vfdev-5 marked this conversation as resolved.
Outdated

67 changes: 67 additions & 0 deletions ignite/metrics/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
import torch
from typing import Callable, Iterable, Sequence, Union


def get_sequence_transform(
    ignore_index: Union[int, Iterable[int], None] = None,
    output_transform: Callable = lambda x: x,
) -> Callable:
    """
    Returns a callable to transform sequence model outputs for metric evaluation.
    It flattens the sequences and filters out the padding (`ignore_index`).

    Args:
        ignore_index: An integer or an iterable of integers representing padding or
            special tokens to be masked out from the sequence evaluation. Iterables
            are consumed once, when this function is called, so one-shot iterables
            (e.g. generators) keep working on every call of the returned transform.
        output_transform: A callable to transform the output into `(y_pred, y)`.

    Returns:
        Callable that flattens `y_pred` and `y` and removes `ignore_index` elements.
        Accepted layouts are `y_pred` of shape (N, L, C) or (N, C, L) with `y` of
        shape (N, L), or `y_pred` and `y` both of shape (N, L). When a 3D `y_pred`
        matches both layouts (L == C), it is treated as (N, L, C).

    Raises:
        ValueError: raised by the returned callable when `y_pred` and `y` do not
            have one of the supported ranks or their shapes do not line up.
    """
    # Materialise the ignored values a single time. Iterating `ignore_index`
    # inside the wrapper would exhaust a generator on the first call and make
    # every later call silently skip the filtering.
    if ignore_index is None:
        ignored_values: tuple = ()
    elif isinstance(ignore_index, Iterable):
        ignored_values = tuple(ignore_index)
    else:
        ignored_values = (ignore_index,)

    def wrapper(output: Sequence[torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
        y_pred, y = output_transform(output)

        if y_pred.ndimension() == 3 and y.ndimension() == 2:
            if y_pred.shape[:2] == y.shape:
                # y_pred is (N, L, C), y is (N, L)
                y_pred = y_pred.reshape(-1, y_pred.size(-1))
                y = y.reshape(-1)
            elif y_pred.shape[0] == y.shape[0] and y_pred.shape[2] == y.shape[1]:
                # y_pred is (N, C, L), y is (N, L): move classes last before flattening
                y_pred = y_pred.transpose(1, 2).reshape(-1, y_pred.size(1))
                y = y.reshape(-1)
            else:
                raise ValueError(
                    f"y_pred and y have incompatible sequence shapes: "
                    f"y_pred={y_pred.shape} vs y={y.shape}"
                )
        elif y_pred.ndimension() == 2 and y.ndimension() == 2:
            # y_pred is (N, L), y is (N, L)
            if y_pred.shape == y.shape:
                y_pred = y_pred.reshape(-1)
                y = y.reshape(-1)
            else:
                raise ValueError(
                    f"y_pred and y have incompatible sequence shapes: "
                    f"y_pred={y_pred.shape} vs y={y.shape}"
                )
        else:
            raise ValueError(
                f"y_pred and y must be 3D and 2D arrays, or both 2D arrays "
                f"for sequence transformation. Got {y_pred.ndimension()}D and {y.ndimension()}D."
            )

        if ignored_values:
            # Keep only the positions whose target is none of the ignored values.
            mask = torch.ones_like(y, dtype=torch.bool)
            for idx in ignored_values:
                mask &= y != idx

            # `mask` is 1D over the flattened positions, so indexing a
            # (N*L, C) `y_pred` with it preserves the trailing class dimension.
            y_pred = y_pred[mask]
            y = y[mask]

        return y_pred, y

    return wrapper
2 changes: 2 additions & 0 deletions tests/ignite/metrics/test_accuracy.py
Original file line number Diff line number Diff line change
Expand Up @@ -498,3 +498,5 @@ def update(self, output):
state = State(output=(y_pred, y_true))
engine = MagicMock(state=state)
acc.iteration_completed(engine)

Comment thread
vfdev-5 marked this conversation as resolved.

64 changes: 64 additions & 0 deletions tests/ignite/metrics/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
import pytest
import torch

from ignite.metrics.utils import get_sequence_transform

def test_get_sequence_transform():
    """Exercise `get_sequence_transform` on every supported layout and on bad shapes.

    Covers (N, L, C) and (N, C, L) class scores, plain (N, L) predictions, the
    no-padding default, several ignored values at once, and the two
    `ValueError` paths.
    """
    pad_transform = get_sequence_transform(ignore_index=-1)

    # Class scores laid out as (N, L, C); the last position of each row is padding.
    scores_nlc = torch.tensor(
        [
            [[0.1, 0.9], [0.8, 0.2], [0.3, 0.7], [0.5, 0.5]],
            [[0.9, 0.1], [0.2, 0.8], [0.4, 0.6], [0.5, 0.5]],
        ]
    )  # shape: (2, 4, 2)
    targets = torch.tensor([[1, 0, 1, -1], [0, 1, 0, -1]])  # shape: (2, 4)

    flat_scores, flat_targets = pad_transform((scores_nlc, targets))

    # The mask is 1D over the flattened positions, so masking the (N*L, C)
    # scores keeps the class axis: the result is 2D, not 1D.
    assert flat_scores.shape == (6, 2)
    assert flat_targets.shape == (6,)
    assert flat_targets.tolist() == [1, 0, 1, 0, 1, 0]
    assert flat_scores[:, 1].tolist() == pytest.approx([0.9, 0.2, 0.7, 0.1, 0.8, 0.6])

    # Same data laid out as (N, C, L) must flatten to the identical result.
    scores_ncl = scores_nlc.transpose(1, 2).contiguous()  # (2, 2, 4)
    flat_scores_ncl, flat_targets_ncl = pad_transform((scores_ncl, targets))
    assert flat_scores_ncl.shape == (6, 2)
    assert torch.all(flat_scores_ncl == flat_scores)
    assert torch.all(flat_targets_ncl == flat_targets)

    # Label predictions given directly as (N, L), with 2 used as the padding value.
    labels_pred = torch.tensor([[1, 0, 1, 1], [0, 1, 0, 0]])
    labels_true = torch.tensor([[1, 0, 1, 2], [0, 1, 0, 2]])
    binary_transform = get_sequence_transform(ignore_index=2)
    kept_pred, kept_true = binary_transform((labels_pred, labels_true))

    assert kept_pred.shape == (6,)
    assert kept_true.shape == (6,)
    assert kept_true.tolist() == [1, 0, 1, 0, 1, 0]
    assert kept_pred.tolist() == [1, 0, 1, 0, 1, 0]

    # Without an ignore_index every position is kept.
    plain_transform = get_sequence_transform()
    all_pred, all_true = plain_transform((labels_pred, labels_true))
    assert all_pred.shape == (8,)
    assert all_true.shape == (8,)

    # Several ignored values supplied at once.
    labels_true = torch.tensor([[1, -1, 1, 2], [0, 1, -1, 2]])
    multi_transform = get_sequence_transform(ignore_index=[-1, 2])
    multi_pred, multi_true = multi_transform((labels_pred, labels_true))
    assert multi_pred.shape == (4,)
    assert multi_true.shape == (4,)
    assert multi_true.tolist() == [1, 1, 0, 1]

    # Unsupported ranks: 2D predictions against 1D targets.
    targets_1d = torch.tensor([1, 0, 1])
    with pytest.raises(ValueError, match="must be 3D and 2D arrays, or both 2D arrays"):
        pad_transform((labels_pred, targets_1d))

    # Supported ranks (3D / 2D) whose sizes match neither layout.
    scores_mismatch = torch.tensor([[[1], [2]], [[3], [4]]])
    targets_mismatch = torch.tensor([[1, 2, 3], [4, 5, 6]])
    with pytest.raises(ValueError, match="incompatible sequence shapes"):
        pad_transform((scores_mismatch, targets_mismatch))