Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions ignite/metrics/precision.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,17 @@ def __init__(
is_multilabel: bool = False,
device: str | torch.device = torch.device("cpu"),
skip_unrolling: bool = False,
class_names: list[str] | None = None,
):
if class_names is not None:
if not isinstance(class_names, (list, tuple)) or not all(isinstance(n, str) for n in class_names):
raise ValueError("class_names must be a list of strings")
if average is not False and average is not None:
raise ValueError(
f"class_names is only applicable when average = False or None,but got average={average!r}"
)
self._class_names = class_names

if not (average is None or isinstance(average, bool) or average in ["macro", "micro", "weighted", "samples"]):
raise ValueError(
"Argument average should be None or a boolean or one of values"
Expand Down Expand Up @@ -157,6 +167,13 @@ def compute(self) -> torch.Tensor | float:
elif self._average == "macro":
return cast(torch.Tensor, fraction).mean().item()
else:
if self._class_names is not None:
if len(self._class_names) != fraction.shape[0]:
raise ValueError(
f"class_names has {len(self._class_names)} entries but the metric computed "
f"{fraction.shape[0]} classes."
)
Comment thread
vfdev-5 marked this conversation as resolved.
Outdated
return dict(zip(self._class_names, fraction.tolist()))
return fraction


Expand Down Expand Up @@ -246,6 +263,10 @@ class Precision(_BasePrecisionRecall):
skip_unrolling: specifies whether output should be unrolled before being fed to update method. Should be
true for multi-output model, for example, if ``y_pred`` contains multi-output as ``(y_pred_a, y_pred_b)``
Alternatively, ``output_transform`` can be used to handle this.
class_names: list of class name strings used to label per-class output when ``average=False``
or ``average=None``. If provided, :meth:`compute` returns a ``dict`` mapping each class
name to its metric value instead of a tensor. Must match the number of classes inferred
from the data. Default: ``None``.

Examples:

Expand Down
81 changes: 81 additions & 0 deletions tests/ignite/metrics/test_precision.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,87 @@ def test_incorrect_y_classes(average):
assert pr._updated is False


@pytest.mark.parametrize(
    "bad_names",
    [
        [0, 1, 2],  # all-int list
        "cat",  # a bare string, not a list
        [0.1, 0.2],  # all-float list
        ["cat", 1],  # strings mixed with non-strings
    ],
)
def test_class_names_invalid_type(bad_names):
    """Constructing with class_names that is not a list of strings raises ValueError."""
    with pytest.raises(ValueError, match="class_names must be a list of strings"):
        Precision(average=False, class_names=bad_names)


@pytest.mark.parametrize("average", ["macro", "micro", "weighted", "samples", True])
def test_class_names_incompatible_average(average):
    """Passing class_names with any averaging mode must raise ValueError.

    NOTE(review): the constructor's actual message reads
    "class_names is only applicable when average = False or None,but got ...",
    so the previous pattern ("... average=False or average=None", no spaces
    around '=') could never match and the test failed even though the
    validation fired. Match on the stable prefix instead, which survives
    cosmetic rewording of the message.
    """
    with pytest.raises(ValueError, match="class_names is only applicable"):
        Precision(average=average, class_names=["cat", "dog", "horse"])


@pytest.mark.parametrize("average", [False, None])
def test_class_names_multiclass(average):
    """With class_names set, per-class results come back as a name -> value dict."""
    names = ["cat", "dog", "horse"]
    metric = Precision(average=average, class_names=names)

    scores = torch.tensor(
        [
            [0.0266, 0.1719, 0.3055],
            [0.6886, 0.3978, 0.8176],
            [0.9230, 0.0197, 0.8395],
            [0.1785, 0.2670, 0.6084],
            [0.8448, 0.7177, 0.7288],
        ]
    )
    targets = torch.tensor([2, 0, 2, 1, 0])
    metric.update((scores, targets))

    per_class = metric.compute()

    assert isinstance(per_class, dict)
    assert list(per_class.keys()) == names
    # argmax predictions are [2, 2, 0, 2, 0] -> precisions 1/2, 0/0 (=0), 1/3
    assert per_class == pytest.approx({"cat": 0.5, "dog": 0.0, "horse": 1 / 3})


def test_class_names_length_mismatch():
    """compute() must fail loudly when class_names disagrees with the class count."""
    metric = Precision(average=False, class_names=["cat", "dog"])

    scores = torch.tensor(
        [
            [0.0266, 0.1719, 0.3055],
            [0.6886, 0.3978, 0.8176],
            [0.9230, 0.0197, 0.8395],
        ]
    )
    metric.update((scores, torch.tensor([2, 0, 1])))

    # Two names were supplied but the data contains three classes.
    with pytest.raises(ValueError, match="class_names has 2 entries but the metric computed 3 classes"):
        metric.compute()


def test_class_names_none_returns_tensor():
    """Without class_names, average=False keeps the historical tensor return type."""
    metric = Precision(average=False)

    scores = torch.tensor(
        [
            [0.0266, 0.1719, 0.3055],
            [0.6886, 0.3978, 0.8176],
            [0.9230, 0.0197, 0.8395],
            [0.1785, 0.2670, 0.6084],
            [0.8448, 0.7177, 0.7288],
        ]
    )
    metric.update((scores, torch.tensor([2, 0, 2, 1, 0])))

    assert isinstance(metric.compute(), torch.Tensor)


@pytest.mark.usefixtures("distributed")
class TestDistributed:
@pytest.mark.parametrize("average", [False, "macro", "weighted", "micro"])
Expand Down