Skip to content
Open
218 changes: 188 additions & 30 deletions ignite/metrics/cohen_kappa.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,161 @@
from collections.abc import Callable
from functools import partial
from typing import Literal

import torch
import torch.nn.functional as F

from ignite.exceptions import NotComputableError
from ignite.metrics.confusion_matrix import ConfusionMatrix
from ignite.metrics.epoch_metric import EpochMetric
from ignite.metrics.metric import Metric, reinit__is_reduced


class CohenKappa(EpochMetric):
"""Compute different types of Cohen's Kappa: Non-Wieghted, Linear, Quadratic.
Accumulating predictions and the ground-truth during an epoch and applying
`sklearn.metrics.cohen_kappa_score <https://scikit-learn.org/stable/modules/
generated/sklearn.metrics.cohen_kappa_score.html>`_ .
def _kappa_from_conf(conf: torch.Tensor, weights: Literal["linear", "quadratic"] | None) -> float:
n = conf.sum()
if n == 0:
raise NotComputableError("CohenKappa cannot be computed on an empty confusion matrix (n == 0).")

n_classes = conf.shape[0]

if weights is None:
p_o = conf.trace() / n
row = conf.sum(dim=1)
col = conf.sum(dim=0)
p_e = (row * col).sum() / (n * n)
else:
idx = torch.arange(n_classes, device=conf.device)
if weights == "linear":
w = torch.abs(idx.unsqueeze(0) - idx.unsqueeze(1)).to(dtype=conf.dtype)
else:
w = ((idx.unsqueeze(0) - idx.unsqueeze(1)) ** 2).to(dtype=conf.dtype)

w = w / w.max()
p_o = 1 - (w * conf).sum() / n
row = conf.sum(dim=1)
col = conf.sum(dim=0)
expected = row.unsqueeze(1) * col.unsqueeze(0) / n
p_e = 1 - (w * expected).sum() / n

if (1 - p_e).abs() < 1e-9:
return 1.0 if (p_o - p_e).abs() < 1e-9 else float("nan")

return ((p_o - p_e) / (1 - p_e)).item()


def _cohen_kappa_score(
    y_pred: torch.Tensor,
    y: torch.Tensor,
    weights: Literal["linear", "quadratic"] | None,
) -> float:
    """Compute Cohen's kappa for a pair of label tensors.

    The number of classes is inferred as the largest label found in either
    tensor plus one. A temporary ``ConfusionMatrix`` is filled from the
    one-hot encoded predictions and the targets, and the resulting matrix is
    handed to ``_kappa_from_conf``.

    Args:
        y_pred: predicted class indices.
        y: ground-truth class indices.
        weights: kappa weighting scheme forwarded to ``_kappa_from_conf``.

    Returns:
        The kappa statistic as a Python float.
    """
    largest_label = max(y_pred.max().item(), y.max().item())
    num_classes = int(largest_label) + 1

    # float32 is used on MPS devices, float64 everywhere else.
    # NOTE(review): a reviewer asked whether this could be read from
    # ``cm._double_dtype`` instead of being re-derived here — confirm that the
    # attribute exists on ConfusionMatrix before switching.
    running_on_mps = y_pred.device.type == "mps"
    double_dtype = torch.float32 if running_on_mps else torch.float64

    cm = ConfusionMatrix(num_classes=num_classes, device=y_pred.device)
    one_hot_pred = F.one_hot(y_pred.long(), num_classes).float()
    cm.update((one_hot_pred, y.long()))

    conf = cm.compute().to(dtype=double_dtype)
    return _kappa_from_conf(conf, weights)


class _CohenKappaEpochMetric(EpochMetric):
    """Cohen's kappa computed through ``EpochMetric``.

    The number of classes is not fixed up front: the compute function
    (``_cohen_kappa_score``) derives it from the labels it is given.
    """

    def __init__(
        self,
        weights: Literal["linear", "quadratic"] | None,
        output_transform: Callable,
        device: str | torch.device,
        skip_unrolling: bool,
    ):
        # Bind the weighting scheme so the compute function has the
        # two-argument form ``(y_pred, y)``.
        score_fn = partial(_cohen_kappa_score, weights=weights)
        super().__init__(
            compute_fn=score_fn,
            output_transform=output_transform,
            check_compute_fn=False,
            device=device,
            skip_unrolling=skip_unrolling,
        )

    @reinit__is_reduced
    def update(self, output: tuple[torch.Tensor, torch.Tensor]) -> None:
        """Validate one batch of ``(y_pred, y)`` and pass it to ``EpochMetric.update``.

        Tensors of shape ``(N, 1)`` are flattened to ``(N,)``.

        Raises:
            ValueError: if either tensor still has more than one dimension.
        """
        y_pred = output[0].detach()
        y = output[1].detach()

        # Collapse column vectors; anything else is left untouched.
        y_pred, y = (t.squeeze(dim=-1) if t.ndim == 2 and t.shape[1] == 1 else t for t in (y_pred, y))

        if max(y_pred.ndim, y.ndim) > 1:
            raise ValueError("multilabel-indicator is not supported")

        super().update((y_pred, y))

    def compute(self) -> float:
        """Return the kappa value from the parent ``compute``.

        Raises:
            NotComputableError: re-raised with a CohenKappa-specific message
                whenever the parent ``compute`` raises ``NotComputableError``.
        """
        try:
            result = super().compute()
        except NotComputableError:
            raise NotComputableError("CohenKappa must have at least one example before it can be computed.")
        return result


class _CohenKappaConfusionMatrix(Metric):
    """Cohen's kappa computed from a running ``ConfusionMatrix``.

    ``num_classes`` must be known when the metric is built. Each batch is
    folded into the wrapped confusion matrix by ``update``; this class does
    not keep the raw prediction/target tensors itself.
    """

    # NOTE(review): ``_cm`` is a Metric instance rather than a tensor or a
    # number — confirm the base-class state_dict machinery accepts that.
    _state_dict_all_req_keys = ("_cm",)

    def __init__(
        self,
        num_classes: int,
        weights: Literal["linear", "quadratic"] | None,
        output_transform: Callable,
        device: str | torch.device,
        skip_unrolling: bool,
    ):
        # Both attributes are assigned before the base initializer runs.
        # NOTE(review): presumably required because the base ``__init__``
        # calls ``reset()``, which needs ``self._cm`` — confirm.
        self._cm = ConfusionMatrix(
            num_classes=num_classes,
            output_transform=output_transform,
            device=device,
            skip_unrolling=skip_unrolling,
        )
        self._weights = weights
        super().__init__(output_transform=output_transform, device=device, skip_unrolling=skip_unrolling)

    @reinit__is_reduced
    def reset(self) -> None:
        """Reset the wrapped confusion matrix."""
        self._cm.reset()

    @reinit__is_reduced
    def update(self, output: tuple[torch.Tensor, torch.Tensor]) -> None:
        """Add one batch of ``(y_pred, y)`` class indices to the confusion matrix.

        Tensors of shape ``(N, 1)`` are flattened to ``(N,)``; predictions are
        one-hot encoded before being handed to ``ConfusionMatrix.update``.

        Raises:
            ValueError: if either tensor still has more than one dimension.
        """
        y_pred = output[0].detach()
        y = output[1].detach()

        # Collapse column vectors; anything else is left untouched.
        y_pred, y = (t.squeeze(dim=-1) if t.ndim == 2 and t.shape[1] == 1 else t for t in (y_pred, y))

        if max(y_pred.ndim, y.ndim) > 1:
            raise ValueError("multilabel-indicator is not supported")

        one_hot_pred = F.one_hot(y_pred.long(), self._cm.num_classes).float().to(self._device)
        targets = y.long().to(self._device)
        self._cm.update((one_hot_pred, targets))

    def compute(self) -> float:
        """Return the kappa value for everything accumulated so far.

        Raises:
            NotComputableError: if the accumulated confusion matrix sums to zero.
        """
        accumulated = self._cm.confusion_matrix.sum()
        if accumulated == 0:
            raise NotComputableError("CohenKappa must have at least one example before it can be computed.")

        # NOTE(review): ``_double_dtype`` is not assigned in this class; it is
        # assumed to be provided by the ``Metric`` base class — confirm.
        conf = self._cm.compute().to(dtype=self._double_dtype)
        return _kappa_from_conf(conf, self._weights)


class CohenKappa(Metric):
"""Compute different types of Cohen's Kappa: Non-Weighted, Linear, Quadratic.

When ``num_classes`` is provided, accumulates a running confusion matrix via
:class:`~ignite.metrics.confusion_matrix.ConfusionMatrix` (memory-efficient, no raw tensor buffering).
When ``num_classes`` is ``None`` (default), buffers predictions and targets via
:class:`~ignite.metrics.EpochMetric` and infers the number of classes from data.

Args:
output_transform: a callable that is used to transform the
Expand All @@ -19,19 +164,17 @@ class CohenKappa(EpochMetric):
you want to compute the metric with respect to one of the outputs.
weights: a string is used to define the type of Cohen's Kappa whether Non-Weighted or Linear
or Quadratic. Default, None.
check_compute_fn: Default False. If True, `cohen_kappa_score
<https://scikit-learn.org/stable/modules/generated/sklearn.metrics.cohen_kappa_score.html>`_
is run on the first batch of data to ensure there are
no issues. User will be warned in case there are any issues computing the function.
device: optional device specification for internal storage.
skip_unrolling: specifies whether output should be unrolled before being fed to update method. Should be
true for multi-output model, for example, if ``y_pred`` contains multi-output as ``(y_pred_a, y_pred_b)``
Alternatively, ``output_transform`` can be used to handle this.
num_classes: number of classes. If provided, uses a running confusion matrix
(memory-efficient). If ``None``, infers from data at compute time (backward-compatible default).

Examples:
To use with ``Engine`` and ``process_function``, simply attach the metric instance to the engine.
The output of the engine's ``process_function`` needs to be in the format of
``(y_pred, y)`` or ``{'y_pred': y_pred, 'y': y, ...}``. If not, ``output_tranform`` can be added
``(y_pred, y)`` or ``{'y_pred': y_pred, 'y': y, ...}``. If not, ``output_transform`` can be added
to the metric to transform the output into the form expected by the metric.

.. include:: defaults.rst
Expand All @@ -52,37 +195,52 @@ class CohenKappa(EpochMetric):

.. versionchanged:: 0.5.1
``skip_unrolling`` argument is added.

.. versionchanged:: 0.6.0
Replaced scikit-learn dependency with a native PyTorch implementation.

.. versionchanged:: 0.6.2
Added ``num_classes`` argument; routes to a running-confusion-matrix backend when provided.
"""

def __init__(
self,
output_transform: Callable = lambda x: x,
weights: Literal["linear", "quadratic"] | None = None,
check_compute_fn: bool = False,
device: str | torch.device = torch.device("cpu"),
skip_unrolling: bool = False,
num_classes: int | None = None,
):
try:
from sklearn.metrics import cohen_kappa_score # noqa: F401
except ImportError:
raise ModuleNotFoundError("This contrib module requires scikit-learn to be installed.")
if weights not in (None, "linear", "quadratic"):
raise ValueError("Kappa Weighting type must be None or linear or quadratic.")

# initialize weights
self.weights: Literal["linear", "quadratic"] | None = weights

super().__init__(
self._cohen_kappa_score,
output_transform=output_transform,
check_compute_fn=check_compute_fn,
device=device,
skip_unrolling=skip_unrolling,
)

def _cohen_kappa_score(self, y_targets: torch.Tensor, y_preds: torch.Tensor) -> float:
from sklearn.metrics import cohen_kappa_score

y_true = y_targets.cpu().numpy()
y_pred = y_preds.cpu().numpy()
return cohen_kappa_score(y_true, y_pred, weights=self.weights)
if num_classes is not None:
self._impl: Metric = _CohenKappaConfusionMatrix(
num_classes=num_classes,
weights=weights,
output_transform=output_transform,
device=device,
skip_unrolling=skip_unrolling,
)
else:
self._impl = _CohenKappaEpochMetric(
weights=weights,
output_transform=output_transform,
device=device,
skip_unrolling=skip_unrolling,
)

super().__init__(output_transform=output_transform, device=device, skip_unrolling=skip_unrolling)

@reinit__is_reduced
def reset(self) -> None:
    """Reset the metric by delegating to the backend stored in ``self._impl``."""
    self._impl.reset()

@reinit__is_reduced
def update(self, output: tuple[torch.Tensor, torch.Tensor]) -> None:
    """Forward one ``(y_pred, y)`` batch, unchanged, to the backend in ``self._impl``."""
    self._impl.update(output)

def compute(self) -> float:
    """Return the kappa value computed by the backend in ``self._impl``."""
    return self._impl.compute()
111 changes: 68 additions & 43 deletions tests/ignite/metrics/test_cohen_kappa.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
import os
from unittest.mock import patch

import pytest
import sklearn
import torch
from sklearn.metrics import cohen_kappa_score

Expand All @@ -14,55 +12,15 @@
torch.manual_seed(12)


@pytest.fixture()
def mock_no_sklearn():
with patch.dict("sys.modules", {"sklearn.metrics": None}):
yield sklearn


def test_no_sklearn(mock_no_sklearn):
with pytest.raises(ModuleNotFoundError, match=r"This contrib module requires scikit-learn to be installed."):
CohenKappa()


def test_no_update():
    """Computing before any update must raise the CohenKappa-specific error."""
    ck = CohenKappa()

    # Only the current message is matched; the earlier
    # "EpochMetric must have ..." pattern was a stale duplicate argument that
    # made the ``pytest.raises`` call invalid.
    with pytest.raises(
        NotComputableError, match=r"CohenKappa must have at least one example before it can be computed"
    ):
        ck.compute()


def test_input_types():
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same question here, why do you remove these tests?

ck = CohenKappa()
ck.reset()
output1 = (torch.rand(4, 3), torch.randint(0, 2, size=(4, 3), dtype=torch.long))
ck.update(output1)

with pytest.raises(ValueError, match=r"Incoherent types between input y_pred and stored predictions"):
ck.update((torch.randint(0, 5, size=(4, 3)), torch.randint(0, 2, size=(4, 3))))

with pytest.raises(ValueError, match=r"Incoherent types between input y and stored targets"):
ck.update((torch.rand(4, 3), torch.randint(0, 2, size=(4, 3)).to(torch.int32)))

with pytest.raises(ValueError, match=r"Incoherent types between input y_pred and stored predictions"):
ck.update((torch.randint(0, 2, size=(10,)).long(), torch.randint(0, 2, size=(10, 5)).long()))


def test_check_shape():
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do you remove these tests?

ck = CohenKappa()

with pytest.raises(ValueError, match=r"Predictions should be of shape"):
ck._check_shape((torch.tensor(0), torch.tensor(0)))

with pytest.raises(ValueError, match=r"Predictions should be of shape"):
ck._check_shape((torch.rand(4, 3, 1), torch.rand(4, 3)))

with pytest.raises(ValueError, match=r"Targets should be of shape"):
ck._check_shape((torch.rand(4, 3), torch.rand(4, 3, 1)))


def test_cohen_kappa_wrong_weights_type():
with pytest.raises(ValueError, match=r"Kappa Weighting type must be"):
ck = CohenKappa(weights=7)
Expand Down Expand Up @@ -329,3 +287,70 @@ def _test_distrib_xla_nprocs(index):
def test_distrib_xla_nprocs(xmp_executor):
n = int(os.environ["NUM_TPU_WORKERS"])
xmp_executor(_test_distrib_xla_nprocs, args=(), nprocs=n)


# --- num_classes path tests ---


def test_num_classes_no_update():
    """With ``num_classes`` set, computing before any update must raise."""
    expected_message = r"CohenKappa must have at least one example before it can be computed"
    metric = CohenKappa(num_classes=3)

    with pytest.raises(NotComputableError, match=expected_message):
        metric.compute()


@pytest.mark.parametrize("weights", [None, "linear", "quadratic"])
def test_num_classes_matches_dynamic(weights, available_device):
    """The fixed-``num_classes`` metric and the default one must agree on the same batches."""
    torch.manual_seed(42)
    n_samples, batch_size = 60, 10
    y_pred = torch.randint(0, 4, size=(n_samples,)).long()
    y = torch.randint(0, 4, size=(n_samples,)).long()

    ck_dynamic = CohenKappa(weights=weights, device=available_device)
    ck_fixed = CohenKappa(weights=weights, device=available_device, num_classes=4)

    # Feed both metrics the same sequence of mini-batches.
    for metric in (ck_dynamic, ck_fixed):
        metric.reset()
        for start in range(0, n_samples, batch_size):
            stop = start + batch_size
            metric.update((y_pred[start:stop], y[start:stop]))

    assert ck_dynamic.compute() == pytest.approx(ck_fixed.compute())


@pytest.mark.parametrize("weights", [None, "linear", "quadratic"])
def test_num_classes_single_batch(weights, available_device):
    """One update on the ``num_classes`` path must match scikit-learn's kappa."""
    torch.manual_seed(0)
    n_samples = 30
    y_pred = torch.randint(0, 3, size=(n_samples,)).long()
    y = torch.randint(0, 3, size=(n_samples,)).long()

    metric = CohenKappa(weights=weights, device=available_device, num_classes=3)
    metric.reset()
    metric.update((y_pred, y))
    res = metric.compute()

    reference = cohen_kappa_score(y.numpy(), y_pred.numpy(), weights=weights)

    assert isinstance(res, float)
    assert reference == pytest.approx(res)


def test_num_classes_multilabel_inputs():
    """2-D multilabel-indicator inputs must be rejected on the ``num_classes`` path."""
    ck = CohenKappa(num_classes=4)
    ck.reset()

    y_pred = torch.randint(0, 2, size=(10, 4)).long()
    y = torch.randint(0, 2, size=(10, 4)).long()

    # Only ``update`` sits inside the context manager, so the test fails if
    # the error comes from anywhere else. The former trailing ``compute()``
    # call was unreachable once ``update`` raised and has been dropped.
    with pytest.raises(ValueError, match=r"multilabel-indicator is not supported"):
        ck.update((y_pred, y))


def test_num_classes_squeeze_n1():
    """Inputs of shape ``(N, 1)`` must be accepted and match scikit-learn on the flattened labels."""
    torch.manual_seed(7)
    column_shape = (20, 1)
    y_pred = torch.randint(0, 2, size=column_shape).long()
    y = torch.randint(0, 2, size=column_shape).long()

    metric = CohenKappa(num_classes=2)
    metric.reset()
    metric.update((y_pred, y))
    res = metric.compute()

    reference = cohen_kappa_score(y.squeeze().numpy(), y_pred.squeeze().numpy())

    assert isinstance(res, float)
    assert reference == pytest.approx(res)
Loading