diff --git a/ignite/metrics/precision.py b/ignite/metrics/precision.py
index be33b268b15e..db2669fe32b7 100644
--- a/ignite/metrics/precision.py
+++ b/ignite/metrics/precision.py
@@ -23,7 +23,20 @@ def __init__(
         is_multilabel: bool = False,
         device: str | torch.device = torch.device("cpu"),
         skip_unrolling: bool = False,
+        class_names: list[str] | None = None,
     ):
+        if class_names is not None:
+            if not isinstance(class_names, (list, tuple)) or not all(isinstance(n, str) for n in class_names):
+                raise ValueError("class_names must be a list of strings")
+            if len(set(class_names)) != len(class_names):
+                # compute() builds a dict keyed by name, so a repeated name would silently drop a class.
+                raise ValueError("class_names must not contain duplicates")
+            if average is not False and average is not None:
+                raise ValueError(
+                    f"class_names is only applicable when average=False or average=None, got average={average!r}."
+                )
+        self._class_names = class_names
+
         if not (average is None or isinstance(average, bool) or average in ["macro", "micro", "weighted", "samples"]):
             raise ValueError(
                 "Argument average should be None or a boolean or one of values"
@@ -157,6 +170,26 @@ def compute(self) -> torch.Tensor | float:
         elif self._average == "macro":
             return cast(torch.Tensor, fraction).mean().item()
         else:
+            if self._class_names is not None:
+                # reshape(-1) keeps a 0-dim fraction iterable so it still pairs with a single name.
+                return dict(zip(self._class_names, cast(torch.Tensor, fraction).reshape(-1).tolist()))
             return fraction
+
+    def _check_class_names(self) -> None:
+        """Verify that ``class_names`` has one entry per accumulated class.
+
+        Reads ``self._numerator``, so call it from ``update`` after the numerator has been accumulated.
+
+        Raises:
+            ValueError: if the number of names differs from the number of accumulated classes.
+        """
+        if self._class_names is None:
+            return
+        # numel() rather than shape[0]: indexing the shape of a 0-dim tensor raises IndexError.
+        num_classes = cast(torch.Tensor, self._numerator).numel()
+        if len(self._class_names) != num_classes:
+            raise ValueError(
+                f"class_names has {len(self._class_names)} entries but the metric computed {num_classes} classes."
+            )
 
 
@@ -246,6 +279,10 @@ class Precision(_BasePrecisionRecall):
         skip_unrolling: specifies whether output should be unrolled before being fed to update method. Should be
             true for multi-output model, for example, if ``y_pred`` contains multi-output as ``(y_pred_a, y_pred_b)``
             Alternatively, ``output_transform`` can be used to handle this.
+        class_names: list of unique class name strings used to label per-class output when ``average=False``
+            or ``average=None``. If provided, :meth:`compute` returns a ``dict`` mapping each class
+            name to its metric value instead of a tensor. Must match the number of classes inferred
+            from the data, otherwise :meth:`update` raises a ``ValueError``. Default: ``None``.
 
     Examples:
 
@@ -428,5 +465,6 @@ def update(self, output: Sequence[torch.Tensor]) -> None:
 
             if self._average == "weighted":
                 self._weight += y.sum(dim=0)
 
+        self._check_class_names()
         self._updated = True
diff --git a/ignite/metrics/recall.py b/ignite/metrics/recall.py
index 34da36ced48f..d4d12e6fc39c 100644
--- a/ignite/metrics/recall.py
+++ b/ignite/metrics/recall.py
@@ -97,7 +97,11 @@ class Recall(_BasePrecisionRecall):
         skip_unrolling: specifies whether output should be unrolled before being fed to update method. Should be
             true for multi-output model, for example, if ``y_pred`` contains multi-output as ``(y_pred_a, y_pred_b)``
             Alternatively, ``output_transform`` can be used to handle this.
+        class_names: list of unique class name strings used to label per-class output when ``average=False``
+            or ``average=None``. If provided, :meth:`compute` returns a ``dict`` mapping each class
+            name to its metric value instead of a tensor. Must match the number of classes inferred
+            from the data, otherwise :meth:`update` raises a ``ValueError``. Default: ``None``.
 
     Examples:
 
         For more information on how metric works with :class:`~ignite.engine.engine.Engine`, visit :ref:`attach-engine`.
@@ -240,5 +244,6 @@ def update(self, output: Sequence[torch.Tensor]) -> None:
 
             if self._average == "weighted":
                 self._weight += y.sum(dim=0)
 
+        self._check_class_names()
         self._updated = True
diff --git a/tests/ignite/metrics/test_precision.py b/tests/ignite/metrics/test_precision.py
index 7bd80540f5de..d2121858c6a3 100644
--- a/tests/ignite/metrics/test_precision.py
+++ b/tests/ignite/metrics/test_precision.py
@@ -325,6 +325,105 @@ def test_incorrect_y_classes(average):
     assert pr._updated is False
 
 
+@pytest.mark.parametrize(
+    "invalid_class_names",
+    [
+        [0, 1, 2],  # list of ints
+        "cat",  # string instead of list
+        [0.1, 0.2],  # list of floats
+        ["cat", 1],  # mixed
+    ],
+)
+def test_class_names_invalid_type(invalid_class_names):
+    with pytest.raises(ValueError, match="class_names must be a list of strings"):
+        Precision(average=False, class_names=invalid_class_names)
+
+
+def test_class_names_duplicates():
+    with pytest.raises(ValueError, match="class_names must not contain duplicates"):
+        Precision(average=False, class_names=["cat", "dog", "cat"])
+
+
+@pytest.mark.parametrize("average", ["macro", "micro", "weighted", "samples", True])
+def test_class_names_incompatible_average(average):
+    with pytest.raises(ValueError, match="class_names is only applicable when average=False or average=None"):
+        Precision(average=average, class_names=["cat", "dog", "horse"])
+
+
+@pytest.mark.parametrize("average", [False, None])
+def test_class_names_multiclass(average):
+    class_names = ["cat", "dog", "horse"]
+    pr = Precision(average=average, class_names=class_names)
+
+    y_pred = torch.tensor(
+        [
+            [0.0266, 0.1719, 0.3055],
+            [0.6886, 0.3978, 0.8176],
+            [0.9230, 0.0197, 0.8395],
+            [0.1785, 0.2670, 0.6084],
+            [0.8448, 0.7177, 0.7288],
+        ]
+    )
+    y = torch.tensor([2, 0, 2, 1, 0])
+
+    pr.update((y_pred, y))
+    result = pr.compute()
+
+    assert isinstance(result, dict)
+    assert list(result.keys()) == class_names
+    assert result == pytest.approx({"cat": 0.5, "dog": 0.0, "horse": 0.3333333333333333})
+
+
+def test_class_names_binary():
+    # Binary input with average=False is expected to yield a single (0-dim) value, hence exactly one name.
+    pr = Precision(average=False, class_names=["positive"])
+
+    y_pred = torch.tensor([1, 0, 1, 1])
+    y = torch.tensor([1, 0, 0, 1])
+
+    pr.update((y_pred, y))
+    result = pr.compute()
+
+    assert isinstance(result, dict)
+    assert result == pytest.approx({"positive": 2 / 3})
+
+
+def test_class_names_length_mismatch():
+    pr = Precision(average=False, class_names=["cat", "dog"])
+
+    y_pred = torch.tensor(
+        [
+            [0.0266, 0.1719, 0.3055],
+            [0.6886, 0.3978, 0.8176],
+            [0.9230, 0.0197, 0.8395],
+        ]
+    )
+    y = torch.tensor([2, 0, 1])
+
+    with pytest.raises(ValueError, match="class_names has 2 entries but the metric computed 3 classes"):
+        pr.update((y_pred, y))
+
+
+def test_class_names_none_returns_tensor():
+    pr = Precision(average=False)
+
+    y_pred = torch.tensor(
+        [
+            [0.0266, 0.1719, 0.3055],
+            [0.6886, 0.3978, 0.8176],
+            [0.9230, 0.0197, 0.8395],
+            [0.1785, 0.2670, 0.6084],
+            [0.8448, 0.7177, 0.7288],
+        ]
+    )
+    y = torch.tensor([2, 0, 2, 1, 0])
+
+    pr.update((y_pred, y))
+    result = pr.compute()
+
+    assert isinstance(result, torch.Tensor)
+
+
 @pytest.mark.usefixtures("distributed")
 class TestDistributed:
     @pytest.mark.parametrize("average", [False, "macro", "weighted", "micro"])