Skip to content
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions ignite/metrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from ignite.metrics.accumulation import Average, GeometricAverage, VariableAccumulation
from ignite.metrics.accuracy import Accuracy
from ignite.metrics.utils import get_sequence_transform
from ignite.metrics.average_precision import AveragePrecision
from ignite.metrics.classification_report import ClassificationReport
from ignite.metrics.cohen_kappa import CohenKappa
Expand Down Expand Up @@ -54,6 +55,7 @@
"Loss",
"MetricGroup",
"MetricsLambda",
"get_sequence_transform",
"MeanAbsoluteError",
"MeanPairwiseDistance",
"MeanSquaredError",
Expand Down
77 changes: 77 additions & 0 deletions ignite/metrics/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
import torch
from typing import Callable, Iterable, Sequence


def get_sequence_transform(
ignore_index: int | Iterable[int] | None = None,
output_transform: Callable = lambda x: x,
) -> Callable:
"""
Returns a callable to transform sequence model outputs for metric evaluation.
It flattens the sequences and filters out the padding (`ignore_index`).

Args:
ignore_index: An integer or an iterable of integers representing padding or
special tokens to be masked out from the sequence evaluation.
output_transform: A callable to transform the output into `(y_pred, y)`.

Returns:
Callable that flattens `y_pred` and `y` and removes `ignore_index` elements.
"""
def wrapper(output: Sequence[torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
y_pred, y = output_transform(output)

if y_pred.ndimension() == 3 and y.ndimension() == 2:
if y_pred.shape[:2] == y.shape:
# y_pred is (N, L, C), y is (N, L)
y_pred = y_pred.reshape(-1, y_pred.size(-1))
y = y.reshape(-1)
elif y_pred.shape[0] == y.shape[0] and y_pred.shape[2] == y.shape[1]:
# y_pred is (N, C, L), y is (N, L)
y_pred = y_pred.transpose(1, 2).reshape(-1, y_pred.size(1))
y = y.reshape(-1)
else:
raise ValueError(
f"y_pred and y have incompatible sequence shapes: "
f"y_pred={y_pred.shape} vs y={y.shape}"
)
elif y_pred.ndimension() == 3 and y.ndimension() == 3:
# y_pred is (N, L, C) or (N, C, L), y has the same shape
if y_pred.shape == y.shape:
y_pred = y_pred.reshape(-1)
y = y.reshape(-1)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For this case the y_pred should be in shape (N,C) this block is turning it to (N*C) and for y argmax should be taken because the Accuracy metric doesn't support one hot encoded labels.

else:
raise ValueError(
f"y_pred and y have incompatible sequence shapes: "
f"y_pred={y_pred.shape} vs y={y.shape}"
)
elif y_pred.ndimension() == 2 and y.ndimension() == 2:
# y_pred is (N, L), y is (N, L)
if y_pred.shape == y.shape:
y_pred = y_pred.reshape(-1)
y = y.reshape(-1)
else:
raise ValueError(
f"y_pred and y have incompatible sequence shapes: "
f"y_pred={y_pred.shape} vs y={y.shape}"
)
else:
raise ValueError(
f"y_pred and y must be 3D/2D, 3D/3D, or 2D/2D arrays "
f"for sequence transformation. Got {y_pred.ndimension()}D and {y.ndimension()}D."
)

if ignore_index is not None:
if isinstance(ignore_index, Iterable):
mask = torch.ones_like(y, dtype=torch.bool)
for idx in ignore_index:
mask &= (y != idx)
else:
mask = y != ignore_index

y_pred = y_pred[mask]
y = y[mask]

return y_pred, y

return wrapper
75 changes: 75 additions & 0 deletions tests/ignite/metrics/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
import pytest
import torch

from ignite.metrics.utils import get_sequence_transform

def test_get_sequence_transform():
    """Check flattening, padding removal and error reporting of ``get_sequence_transform``."""
    # --- (N, L, C) predictions with (N, L) targets ---
    scores_nlc = torch.tensor(
        [
            [[0.1, 0.9], [0.8, 0.2], [0.3, 0.7], [0.5, 0.5]],
            [[0.9, 0.1], [0.2, 0.8], [0.4, 0.6], [0.5, 0.5]],
        ]
    )  # shape: (2, 4, 2)
    targets = torch.tensor([[1, 0, 1, -1], [0, 1, 0, -1]])  # shape: (2, 4)

    drop_padding = get_sequence_transform(ignore_index=-1)
    flat_scores, flat_targets = drop_padding((scores_nlc, targets))

    # The 1D boolean mask only selects rows of the flattened (N * L, C) scores,
    # so the class dimension survives and the result is 2D, not 1D.
    assert flat_scores.shape == (6, 2)
    assert flat_targets.shape == (6,)
    assert flat_targets.tolist() == [1, 0, 1, 0, 1, 0]
    assert flat_scores[:, 1].tolist() == pytest.approx([0.9, 0.2, 0.7, 0.1, 0.8, 0.6])

    # --- (N, C, L) predictions must give the same result as (N, L, C) ---
    scores_ncl = scores_nlc.transpose(1, 2).contiguous()  # (2, 2, 4)
    flat_scores_ncl, flat_targets_ncl = drop_padding((scores_ncl, targets))
    assert flat_scores_ncl.shape == (6, 2)
    assert torch.all(flat_scores_ncl == flat_scores)
    assert torch.all(flat_targets_ncl == flat_targets)

    # --- binary (N, L) predictions, padding encoded as 2 ---
    binary_pred = torch.tensor([[1, 0, 1, 1], [0, 1, 0, 0]])
    binary_target = torch.tensor([[1, 0, 1, 2], [0, 1, 0, 2]])
    drop_twos = get_sequence_transform(ignore_index=2)
    kept_pred, kept_target = drop_twos((binary_pred, binary_target))

    assert kept_pred.shape == (6,)
    assert kept_target.shape == (6,)
    assert kept_target.tolist() == [1, 0, 1, 0, 1, 0]
    assert kept_pred.tolist() == [1, 0, 1, 0, 1, 0]

    # --- no ignore_index: everything is kept, only flattened ---
    keep_everything = get_sequence_transform()
    all_pred, all_target = keep_everything((binary_pred, binary_target))
    assert all_pred.shape == (8,)
    assert all_target.shape == (8,)

    # --- several ignore_index values at once ---
    mixed_target = torch.tensor([[1, -1, 1, 2], [0, 1, -1, 2]])
    drop_many = get_sequence_transform(ignore_index=[-1, 2])
    many_pred, many_target = drop_many((binary_pred, mixed_target))
    assert many_pred.shape == (4,)
    assert many_target.shape == (4,)
    assert many_target.tolist() == [1, 1, 0, 1]

    # --- 3D predictions with 3D targets of identical shape ---
    scores_3d = torch.tensor([[[0.1, 0.9], [0.8, 0.2]], [[0.3, 0.7], [0.5, 0.5]]])
    targets_3d = torch.tensor([[[1, 0], [1, 1]], [[0, 1], [0, 0]]])
    flatten_only = get_sequence_transform()
    flat_scores_3d, flat_targets_3d = flatten_only((scores_3d, targets_3d))

    assert flat_scores_3d.shape == (8,)
    assert flat_targets_3d.shape == (8,)
    assert flat_scores_3d.tolist() == pytest.approx([0.1, 0.9, 0.8, 0.2, 0.3, 0.7, 0.5, 0.5])
    assert flat_targets_3d.tolist() == [1, 0, 1, 1, 0, 1, 0, 0]

    # --- unsupported rank combination (2D predictions, 1D targets) ---
    targets_1d = torch.tensor([1, 0, 1])
    with pytest.raises(ValueError, match="must be 3D/2D, 3D/3D, or 2D/2D arrays"):
        drop_padding((binary_pred, targets_1d))

    # --- supported ranks (3D/2D) but shapes that fit neither layout ---
    mismatched_pred = torch.tensor([[[1], [2]], [[3], [4]]])
    mismatched_target = torch.tensor([[1, 2, 3], [4, 5, 6]])
    with pytest.raises(ValueError, match="incompatible sequence shapes"):
        drop_padding((mismatched_pred, mismatched_target))