from collections.abc import Callable
from functools import partial
from typing import Literal

import torch
import torch.nn.functional as F

from ignite.exceptions import NotComputableError
from ignite.metrics.confusion_matrix import ConfusionMatrix
from ignite.metrics.epoch_metric import EpochMetric
from ignite.metrics.metric import Metric, reinit__is_reduced


def _normalize_inputs(y_pred: torch.Tensor, y: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Detach inputs, squeeze ``(N, 1)`` tensors to 1-D and reject multilabel input.

    Shared by both CohenKappa backends so their input contracts stay identical.

    Raises:
        ValueError: if either tensor is still more than 1-D after squeezing.
    """
    y_pred, y = y_pred.detach(), y.detach()

    if y_pred.ndim == 2 and y_pred.shape[1] == 1:
        y_pred = y_pred.squeeze(dim=-1)
    if y.ndim == 2 and y.shape[1] == 1:
        y = y.squeeze(dim=-1)

    if y_pred.ndim > 1 or y.ndim > 1:
        raise ValueError("multilabel-indicator is not supported")

    return y_pred, y


def _kappa_from_conf(conf: torch.Tensor, weights: Literal["linear", "quadratic"] | None) -> float:
    """Compute Cohen's kappa from an accumulated confusion matrix.

    Args:
        conf: square ``(num_classes, num_classes)`` confusion matrix of counts.
        weights: ``None`` for plain kappa, ``"linear"`` or ``"quadratic"`` for weighted kappa.

    Returns:
        The kappa score as a Python float (NaN when it is undefined).

    Raises:
        NotComputableError: if the confusion matrix is empty (no samples).
    """
    n = conf.sum()
    if n == 0:
        raise NotComputableError("CohenKappa cannot be computed on an empty confusion matrix (n == 0).")

    n_classes = conf.shape[0]

    if weights is None:
        # p_o: observed agreement (diagonal mass); p_e: chance agreement from marginals.
        p_o = conf.trace() / n
        row = conf.sum(dim=1)
        col = conf.sum(dim=0)
        p_e = (row * col).sum() / (n * n)
    else:
        idx = torch.arange(n_classes, device=conf.device)
        if weights == "linear":
            w = torch.abs(idx.unsqueeze(0) - idx.unsqueeze(1)).to(dtype=conf.dtype)
        else:
            w = ((idx.unsqueeze(0) - idx.unsqueeze(1)) ** 2).to(dtype=conf.dtype)

        # Scaling w by a constant cancels out of the kappa ratio, so this
        # normalization only conditions the intermediates. Guard the
        # single-class case where w.max() == 0 would turn it into 0/0.
        max_w = w.max()
        if max_w > 0:
            w = w / max_w

        # Weighted kappa via disagreement mass: p_o/p_e are 1 minus the
        # weighted observed/expected disagreement rates.
        p_o = 1 - (w * conf).sum() / n
        row = conf.sum(dim=1)
        col = conf.sum(dim=0)
        expected = row.unsqueeze(1) * col.unsqueeze(0) / n
        p_e = 1 - (w * expected).sum() / n

    # Degenerate case p_e == 1: no chance-corrected information. Perfect
    # agreement yields 1.0 by convention, otherwise kappa is undefined.
    if (1 - p_e).abs() < 1e-9:
        return 1.0 if (p_o - p_e).abs() < 1e-9 else float("nan")

    return ((p_o - p_e) / (1 - p_e)).item()


def _cohen_kappa_score(
    y_pred: torch.Tensor,
    y: torch.Tensor,
    weights: Literal["linear", "quadratic"] | None,
) -> float:
    """Compute Cohen's kappa for 1-D label tensors, inferring num_classes from the data."""
    # NOTE(review): assumes y_pred/y hold integer-valued class labels; float
    # scores would be silently truncated by .long() below — confirm callers.
    num_classes = int(max(y_pred.max().item(), y.max().item())) + 1

    # Reuse ConfusionMatrix for the tally; it expects (N, num_classes) scores,
    # so feed one-hot encoded predictions.
    cm = ConfusionMatrix(num_classes=num_classes, device=y_pred.device)
    y_pred_oh = F.one_hot(y_pred.long(), num_classes).float()
    cm.update((y_pred_oh, y.long()))
    conf = cm.compute().to(dtype=cm._double_dtype)

    return _kappa_from_conf(conf, weights)


class _CohenKappaEpochMetric(EpochMetric):
    """CohenKappa backed by EpochMetric — infers num_classes dynamically from data.

    Buffers predictions and targets for the whole epoch and computes kappa once
    at ``compute()`` time.
    """

    def __init__(
        self,
        weights: Literal["linear", "quadratic"] | None,
        output_transform: Callable,
        device: str | torch.device,
        skip_unrolling: bool,
    ):
        super().__init__(
            compute_fn=partial(_cohen_kappa_score, weights=weights),
            output_transform=output_transform,
            check_compute_fn=False,
            device=device,
            skip_unrolling=skip_unrolling,
        )

    @reinit__is_reduced
    def update(self, output: tuple[torch.Tensor, torch.Tensor]) -> None:
        # Normalize shapes / reject multilabel before EpochMetric buffers them.
        super().update(_normalize_inputs(output[0], output[1]))

    def compute(self) -> float:
        try:
            return super().compute()
        except NotComputableError:
            # Re-raise with a CohenKappa-specific message; suppress the
            # EpochMetric-level context so the traceback stays readable.
            raise NotComputableError(
                "CohenKappa must have at least one example before it can be computed."
            ) from None


class _CohenKappaConfusionMatrix(Metric):
    """CohenKappa backed by ConfusionMatrix — requires num_classes at construction time.

    Accumulates a running confusion matrix; no raw tensor buffering.
    """

    _state_dict_all_req_keys = ("_cm",)

    def __init__(
        self,
        num_classes: int,
        weights: Literal["linear", "quadratic"] | None,
        output_transform: Callable,
        device: str | torch.device,
        skip_unrolling: bool,
    ):
        self._weights = weights
        # The inner ConfusionMatrix is only ever updated directly (never
        # attached to an engine), so it must NOT carry its own
        # output_transform / skip_unrolling — the outer metric's machinery
        # already applies them, and forwarding them here risks double
        # application if the inner metric's hooks ever fire.
        self._cm = ConfusionMatrix(num_classes=num_classes, device=device)
        super().__init__(output_transform=output_transform, device=device, skip_unrolling=skip_unrolling)

    @reinit__is_reduced
    def reset(self) -> None:
        self._cm.reset()

    @reinit__is_reduced
    def update(self, output: tuple[torch.Tensor, torch.Tensor]) -> None:
        y_pred, y = _normalize_inputs(output[0], output[1])

        # ConfusionMatrix expects (N, num_classes) scores: one-hot the labels.
        num_classes = self._cm.num_classes
        y_pred_oh = F.one_hot(y_pred.long(), num_classes).float().to(self._device)
        self._cm.update((y_pred_oh, y.long().to(self._device)))

    def compute(self) -> float:
        # An all-zero matrix means no update() succeeded since the last reset().
        if self._cm.confusion_matrix.sum() == 0:
            raise NotComputableError("CohenKappa must have at least one example before it can be computed.")
        conf = self._cm.compute().to(dtype=self._double_dtype)
        return _kappa_from_conf(conf, self._weights)


class CohenKappa(Metric):
    """Compute different types of Cohen's Kappa: Non-Weighted, Linear, Quadratic.

    When ``num_classes`` is provided, accumulates a running confusion matrix via
    :class:`~ignite.metrics.confusion_matrix.ConfusionMatrix` (memory-efficient, no raw tensor buffering).
    When ``num_classes`` is ``None`` (default), buffers predictions and targets via
    :class:`~ignite.metrics.EpochMetric` and infers the number of classes from data.

    Args:
        output_transform: a callable that is used to transform the
            :class:`~ignite.engine.engine.Engine`'s ``process_function``'s output into the
            form expected by the metric. This can be useful if, for example, you have a
            multi-output model and you want to compute the metric with respect to one of the outputs.
        weights: a string is used to define the type of Cohen's Kappa whether
            Non-Weighted or Linear or Quadratic. Default, None.
        device: optional device specification for internal storage.
        skip_unrolling: specifies whether output should be unrolled before being fed to update method. Should be
            true for multi-output model, for example, if ``y_pred`` contains multi-output as ``(y_pred_a, y_pred_b)``
            Alternatively, ``output_transform`` can be used to handle this.
        num_classes: number of classes. If provided, uses a running confusion matrix
            (memory-efficient). If ``None``, infers from data at compute time (backward-compatible default).

    Examples:
        To use with ``Engine`` and ``process_function``, simply attach the metric instance to the engine.
        The output of the engine's ``process_function`` needs to be in the format of
        ``(y_pred, y)`` or ``{'y_pred': y_pred, 'y': y, ...}``. If not, ``output_transform`` can be added
        to the metric to transform the output into the form expected by the metric.

        .. include:: defaults.rst
            :start-after: :orphan:

    .. versionchanged:: 0.5.1
        ``skip_unrolling`` argument is added.

    .. versionchanged:: 0.6.0
        Replaced scikit-learn dependency with a native PyTorch implementation.

    .. versionchanged:: 0.6.2
        Added ``num_classes`` argument; routes to a running-confusion-matrix backend when provided.
    """

    def __init__(
        self,
        output_transform: Callable = lambda x: x,
        weights: Literal["linear", "quadratic"] | None = None,
        device: str | torch.device = torch.device("cpu"),
        skip_unrolling: bool = False,
        num_classes: int | None = None,
    ):
        if weights not in (None, "linear", "quadratic"):
            raise ValueError("Kappa Weighting type must be None or linear or quadratic.")

        self.weights: Literal["linear", "quadratic"] | None = weights

        # Select the backend before Metric.__init__, which calls self.reset()
        # and therefore needs self._impl to exist already.
        if num_classes is not None:
            self._impl: Metric = _CohenKappaConfusionMatrix(
                num_classes=num_classes,
                weights=weights,
                output_transform=output_transform,
                device=device,
                skip_unrolling=skip_unrolling,
            )
        else:
            self._impl = _CohenKappaEpochMetric(
                weights=weights,
                output_transform=output_transform,
                device=device,
                skip_unrolling=skip_unrolling,
            )

        super().__init__(output_transform=output_transform, device=device, skip_unrolling=skip_unrolling)

    @reinit__is_reduced
    def reset(self) -> None:
        self._impl.reset()

    @reinit__is_reduced
    def update(self, output: tuple[torch.Tensor, torch.Tensor]) -> None:
        # output_transform/unrolling were already applied by this metric's own
        # engine hooks; delegate the raw (y_pred, y) pair to the backend.
        self._impl.update(output)

    def compute(self) -> float:
        return self._impl.compute()
# Changed portion of tests/ignite/metrics/test_cohen_kappa.py (post-patch state).
# Unchanged tests elided by the diff are not reproduced here.
import os

import pytest
import torch
from sklearn.metrics import cohen_kappa_score

from ignite.exceptions import NotComputableError
from ignite.metrics import CohenKappa

torch.manual_seed(12)


def test_no_update():
    ck = CohenKappa()
    with pytest.raises(
        NotComputableError, match=r"CohenKappa must have at least one example before it can be computed"
    ):
        ck.compute()


def test_input_types():
    ck = CohenKappa()
    ck.reset()
    output1 = (torch.rand(10), torch.randint(0, 2, size=(10,), dtype=torch.long))
    ck.update(output1)

    with pytest.raises(ValueError, match=r"Incoherent types between input y_pred and stored predictions"):
        ck.update((torch.randint(0, 5, size=(10,)), torch.randint(0, 2, size=(10,))))

    with pytest.raises(ValueError, match=r"Incoherent types between input y and stored targets"):
        ck.update((torch.rand(10), torch.randint(0, 2, size=(10,)).to(torch.int32)))


def test_check_shape():
    ck = CohenKappa()
    impl = ck._impl

    with pytest.raises(ValueError, match=r"Predictions should be of shape"):
        impl._check_shape((torch.tensor(0), torch.tensor(0)))

    with pytest.raises(ValueError, match=r"Predictions should be of shape"):
        impl._check_shape((torch.rand(4, 3, 1), torch.rand(4, 3)))

    with pytest.raises(ValueError, match=r"Targets should be of shape"):
        impl._check_shape((torch.rand(4, 3), torch.rand(4, 3, 1)))


# --- num_classes path tests ---


def test_num_classes_no_update():
    ck = CohenKappa(num_classes=3)
    with pytest.raises(
        NotComputableError, match=r"CohenKappa must have at least one example before it can be computed"
    ):
        ck.compute()


@pytest.mark.parametrize("weights", [None, "linear", "quadratic"])
def test_num_classes_matches_dynamic(weights, available_device):
    torch.manual_seed(42)
    y_pred = torch.randint(0, 4, size=(60,)).long()
    y = torch.randint(0, 4, size=(60,)).long()
    batch_size = 10

    ck_dynamic = CohenKappa(weights=weights, device=available_device)
    ck_fixed = CohenKappa(weights=weights, device=available_device, num_classes=4)

    for ck in (ck_dynamic, ck_fixed):
        ck.reset()
        for i in range(60 // batch_size):
            idx = i * batch_size
            ck.update((y_pred[idx : idx + batch_size], y[idx : idx + batch_size]))

    assert ck_dynamic.compute() == pytest.approx(ck_fixed.compute())


@pytest.mark.parametrize("weights", [None, "linear", "quadratic"])
def test_num_classes_single_batch(weights, available_device):
    torch.manual_seed(0)
    y_pred = torch.randint(0, 3, size=(30,)).long()
    y = torch.randint(0, 3, size=(30,)).long()

    ck = CohenKappa(weights=weights, device=available_device, num_classes=3)
    ck.reset()
    ck.update((y_pred, y))
    res = ck.compute()

    assert isinstance(res, float)
    assert cohen_kappa_score(y.numpy(), y_pred.numpy(), weights=weights) == pytest.approx(res)


def test_num_classes_multilabel_inputs():
    ck = CohenKappa(num_classes=4)
    # reset() outside the raises block: only update() is expected to raise,
    # and a compute() after a raising update() would be unreachable anyway.
    ck.reset()
    with pytest.raises(ValueError, match=r"multilabel-indicator is not supported"):
        ck.update((torch.randint(0, 2, size=(10, 4)).long(), torch.randint(0, 2, size=(10, 4)).long()))


def test_num_classes_squeeze_n1():
    torch.manual_seed(7)
    y_pred = torch.randint(0, 2, size=(20, 1)).long()
    y = torch.randint(0, 2, size=(20, 1)).long()

    ck = CohenKappa(num_classes=2)
    ck.reset()
    ck.update((y_pred, y))
    res = ck.compute()

    assert isinstance(res, float)
    assert cohen_kappa_score(y.squeeze().numpy(), y_pred.squeeze().numpy()) == pytest.approx(res)