217 changes: 187 additions & 30 deletions ignite/metrics/cohen_kappa.py
@@ -1,16 +1,160 @@
from collections.abc import Callable
from functools import partial
from typing import Literal

import torch
import torch.nn.functional as F

from ignite.exceptions import NotComputableError
from ignite.metrics.confusion_matrix import ConfusionMatrix
from ignite.metrics.epoch_metric import EpochMetric
from ignite.metrics.metric import Metric, reinit__is_reduced


class CohenKappa(EpochMetric):
"""Compute different types of Cohen's Kappa: Non-Wieghted, Linear, Quadratic.
Accumulating predictions and the ground-truth during an epoch and applying
`sklearn.metrics.cohen_kappa_score <https://scikit-learn.org/stable/modules/
generated/sklearn.metrics.cohen_kappa_score.html>`_ .
def _kappa_from_conf(conf: torch.Tensor, weights: Literal["linear", "quadratic"] | None) -> float:
n = conf.sum()
if n == 0:
raise NotComputableError("CohenKappa cannot be computed on an empty confusion matrix (n == 0).")

n_classes = conf.shape[0]

if weights is None:
p_o = conf.trace() / n
row = conf.sum(dim=1)
col = conf.sum(dim=0)
p_e = (row * col).sum() / (n * n)
else:
idx = torch.arange(n_classes, device=conf.device)
if weights == "linear":
w = torch.abs(idx.unsqueeze(0) - idx.unsqueeze(1)).double()
else:
w = ((idx.unsqueeze(0) - idx.unsqueeze(1)) ** 2).double()

w = w / w.max()
p_o = 1 - (w * conf).sum() / n
row = conf.sum(dim=1)
col = conf.sum(dim=0)
expected = row.unsqueeze(1) * col.unsqueeze(0) / n
p_e = 1 - (w * expected).sum() / n

if (1 - p_e).abs() < 1e-9:
return 1.0 if (p_o - p_e).abs() < 1e-9 else float("nan")

return ((p_o - p_e) / (1 - p_e)).item()


def _cohen_kappa_score(
y_pred: torch.Tensor,
y: torch.Tensor,
weights: Literal["linear", "quadratic"] | None,
) -> float:
num_classes = int(max(y_pred.max().item(), y.max().item())) + 1

cm = ConfusionMatrix(num_classes=num_classes, device=y_pred.device)
y_pred_oh = F.one_hot(y_pred.long(), num_classes).float()
cm.update((y_pred_oh, y.long()))
conf = cm.compute().cpu().double()

return _kappa_from_conf(conf, weights)


class _CohenKappaEpochMetric(EpochMetric):
"""CohenKappa backed by EpochMetric — infers num_classes dynamically from data."""

def __init__(
self,
weights: Literal["linear", "quadratic"] | None,
output_transform: Callable,
device: str | torch.device,
skip_unrolling: bool,
):
super().__init__(
compute_fn=partial(_cohen_kappa_score, weights=weights),
output_transform=output_transform,
check_compute_fn=False,
device=device,
skip_unrolling=skip_unrolling,
)

@reinit__is_reduced
def update(self, output: tuple[torch.Tensor, torch.Tensor]) -> None:
y_pred, y = output[0].detach(), output[1].detach()

if y_pred.ndim == 2 and y_pred.shape[1] == 1:
y_pred = y_pred.squeeze(dim=-1)
if y.ndim == 2 and y.shape[1] == 1:
y = y.squeeze(dim=-1)

if y_pred.ndim > 1 or y.ndim > 1:
raise ValueError("multilabel-indicator is not supported")

super().update((y_pred, y))

def compute(self) -> float:
try:
return super().compute()
except NotComputableError:
raise NotComputableError("CohenKappa must have at least one example before it can be computed.")


class _CohenKappaConfusionMatrix(Metric):
"""CohenKappa backed by ConfusionMatrix — requires num_classes at construction time.
Accumulates a running confusion matrix; no raw tensor buffering.
"""

_state_dict_all_req_keys = ("_cm",)

def __init__(
self,
num_classes: int,
weights: Literal["linear", "quadratic"] | None,
output_transform: Callable,
device: str | torch.device,
skip_unrolling: bool,
):
self._weights = weights
self._cm = ConfusionMatrix(
num_classes=num_classes,
output_transform=output_transform,
device=device,
skip_unrolling=skip_unrolling,
)
super().__init__(output_transform=output_transform, device=device, skip_unrolling=skip_unrolling)

@reinit__is_reduced
def reset(self) -> None:
self._cm.reset()

@reinit__is_reduced
def update(self, output: tuple[torch.Tensor, torch.Tensor]) -> None:
y_pred, y = output[0].detach(), output[1].detach()

if y_pred.ndim == 2 and y_pred.shape[1] == 1:
y_pred = y_pred.squeeze(dim=-1)
if y.ndim == 2 and y.shape[1] == 1:
y = y.squeeze(dim=-1)

if y_pred.ndim > 1 or y.ndim > 1:
raise ValueError("multilabel-indicator is not supported")

num_classes = self._cm.num_classes
y_pred_oh = F.one_hot(y_pred.long(), num_classes).float().to(self._device)
self._cm.update((y_pred_oh, y.long().to(self._device)))

def compute(self) -> float:
if self._cm.confusion_matrix.sum() == 0:
raise NotComputableError("CohenKappa must have at least one example before it can be computed.")
conf = self._cm.compute().cpu().double()
return _kappa_from_conf(conf, self._weights)


class CohenKappa(Metric):
"""Compute different types of Cohen's Kappa: Non-Weighted, Linear, Quadratic.

When ``num_classes`` is provided, accumulates a running confusion matrix via
:class:`~ignite.metrics.ConfusionMatrix` (memory-efficient, no raw tensor buffering).
When ``num_classes`` is ``None`` (default), buffers predictions and targets via
:class:`~ignite.metrics.EpochMetric` and infers the number of classes from data.

Args:
output_transform: a callable that is used to transform the
@@ -19,19 +163,17 @@ class CohenKappa(EpochMetric):
you want to compute the metric with respect to one of the outputs.
weights: a string is used to define the type of Cohen's Kappa whether Non-Weighted or Linear
or Quadratic. Default, None.
check_compute_fn: Default False. If True, `cohen_kappa_score
<https://scikit-learn.org/stable/modules/generated/sklearn.metrics.cohen_kappa_score.html>`_
is run on the first batch of data to ensure there are
no issues. User will be warned in case there are any issues computing the function.
device: optional device specification for internal storage.
skip_unrolling: specifies whether output should be unrolled before being fed to update method. Should be
true for multi-output model, for example, if ``y_pred`` contains multi-output as ``(y_pred_a, y_pred_b)``
Alternatively, ``output_transform`` can be used to handle this.
num_classes: number of classes. If provided, uses a running confusion matrix
(memory-efficient). If ``None``, infers from data at compute time (backward-compatible default).

Examples:
To use with ``Engine`` and ``process_function``, simply attach the metric instance to the engine.
The output of the engine's ``process_function`` needs to be in the format of
``(y_pred, y)`` or ``{'y_pred': y_pred, 'y': y, ...}``. If not, ``output_tranform`` can be added
``(y_pred, y)`` or ``{'y_pred': y_pred, 'y': y, ...}``. If not, ``output_transform`` can be added
to the metric to transform the output into the form expected by the metric.

.. include:: defaults.rst
@@ -52,37 +194,52 @@ class CohenKappa(EpochMetric):

.. versionchanged:: 0.5.1
``skip_unrolling`` argument is added.

.. versionchanged:: 0.6.0
Replaced scikit-learn dependency with a native PyTorch implementation.

.. versionchanged:: 0.6.2
Added ``num_classes`` argument; routes to a running-confusion-matrix backend when provided.
"""

def __init__(
self,
output_transform: Callable = lambda x: x,
weights: Literal["linear", "quadratic"] | None = None,
check_compute_fn: bool = False,
device: str | torch.device = torch.device("cpu"),
skip_unrolling: bool = False,
num_classes: int | None = None,
):
try:
from sklearn.metrics import cohen_kappa_score # noqa: F401
except ImportError:
raise ModuleNotFoundError("This contrib module requires scikit-learn to be installed.")
if weights not in (None, "linear", "quadratic"):
raise ValueError("Kappa Weighting type must be None or linear or quadratic.")

# initialize weights
self.weights: Literal["linear", "quadratic"] | None = weights

super().__init__(
self._cohen_kappa_score,
output_transform=output_transform,
check_compute_fn=check_compute_fn,
device=device,
skip_unrolling=skip_unrolling,
)

def _cohen_kappa_score(self, y_targets: torch.Tensor, y_preds: torch.Tensor) -> float:
from sklearn.metrics import cohen_kappa_score

y_true = y_targets.cpu().numpy()
y_pred = y_preds.cpu().numpy()
return cohen_kappa_score(y_true, y_pred, weights=self.weights)
if num_classes is not None:
self._impl: Metric = _CohenKappaConfusionMatrix(
num_classes=num_classes,
weights=weights,
output_transform=output_transform,
device=device,
skip_unrolling=skip_unrolling,
)
else:
self._impl = _CohenKappaEpochMetric(
weights=weights,
output_transform=output_transform,
device=device,
skip_unrolling=skip_unrolling,
)

super().__init__(output_transform=output_transform, device=device, skip_unrolling=skip_unrolling)

@reinit__is_reduced
def reset(self) -> None:
self._impl.reset()

@reinit__is_reduced
def update(self, output: tuple[torch.Tensor, torch.Tensor]) -> None:
self._impl.update(output)

def compute(self) -> float:
return self._impl.compute()
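
For reference, the statistic that ``_kappa_from_conf`` evaluates can be written out explicitly. This is only a restatement of the code above, with C the accumulated confusion matrix, r_i and c_j its row and column sums, and n the total count; the weight matrix is scaled by its maximum as in the code, which cancels in the ratio:

    \kappa = \frac{p_o - p_e}{1 - p_e}, \qquad
    p_o = 1 - \frac{1}{n} \sum_{i,j} w_{ij} C_{ij}, \qquad
    p_e = 1 - \frac{1}{n^2} \sum_{i,j} w_{ij} \, r_i c_j

with w_{ij} = |i - j| for ``linear`` and w_{ij} = (i - j)^2 for ``quadratic``; taking w_{ij} = 1 - \delta_{ij} recovers the unweighted case, i.e. p_o = \mathrm{tr}(C)/n and p_e = \sum_i r_i c_i / n^2.
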
111 changes: 68 additions & 43 deletions tests/ignite/metrics/test_cohen_kappa.py
@@ -1,8 +1,6 @@
import os
from unittest.mock import patch

import pytest
import sklearn
import torch
from sklearn.metrics import cohen_kappa_score

@@ -14,55 +12,15 @@
torch.manual_seed(12)


@pytest.fixture()
def mock_no_sklearn():
with patch.dict("sys.modules", {"sklearn.metrics": None}):
yield sklearn


def test_no_sklearn(mock_no_sklearn):
with pytest.raises(ModuleNotFoundError, match=r"This contrib module requires scikit-learn to be installed."):
CohenKappa()


def test_no_update():
ck = CohenKappa()

with pytest.raises(
NotComputableError, match=r"EpochMetric must have at least one example before it can be computed"
NotComputableError, match=r"CohenKappa must have at least one example before it can be computed"
):
ck.compute()


def test_input_types():
Collaborator: Same question here, why do you remove these tests?

ck = CohenKappa()
ck.reset()
output1 = (torch.rand(4, 3), torch.randint(0, 2, size=(4, 3), dtype=torch.long))
ck.update(output1)

with pytest.raises(ValueError, match=r"Incoherent types between input y_pred and stored predictions"):
ck.update((torch.randint(0, 5, size=(4, 3)), torch.randint(0, 2, size=(4, 3))))

with pytest.raises(ValueError, match=r"Incoherent types between input y and stored targets"):
ck.update((torch.rand(4, 3), torch.randint(0, 2, size=(4, 3)).to(torch.int32)))

with pytest.raises(ValueError, match=r"Incoherent types between input y_pred and stored predictions"):
ck.update((torch.randint(0, 2, size=(10,)).long(), torch.randint(0, 2, size=(10, 5)).long()))


def test_check_shape():
Collaborator: Why do you remove these tests?

Author: I accidentally committed it; I was trying to do something with git and it pushed the committed changes.

ck = CohenKappa()

with pytest.raises(ValueError, match=r"Predictions should be of shape"):
ck._check_shape((torch.tensor(0), torch.tensor(0)))

with pytest.raises(ValueError, match=r"Predictions should be of shape"):
ck._check_shape((torch.rand(4, 3, 1), torch.rand(4, 3)))

with pytest.raises(ValueError, match=r"Targets should be of shape"):
ck._check_shape((torch.rand(4, 3), torch.rand(4, 3, 1)))


def test_cohen_kappa_wrong_weights_type():
with pytest.raises(ValueError, match=r"Kappa Weighting type must be"):
ck = CohenKappa(weights=7)
@@ -329,3 +287,70 @@ def _test_distrib_xla_nprocs(index):
def test_distrib_xla_nprocs(xmp_executor):
n = int(os.environ["NUM_TPU_WORKERS"])
xmp_executor(_test_distrib_xla_nprocs, args=(), nprocs=n)


# --- num_classes path tests ---


def test_num_classes_no_update():
ck = CohenKappa(num_classes=3)
with pytest.raises(
NotComputableError, match=r"CohenKappa must have at least one example before it can be computed"
):
ck.compute()


@pytest.mark.parametrize("weights", [None, "linear", "quadratic"])
def test_num_classes_matches_dynamic(weights, available_device):
torch.manual_seed(42)
y_pred = torch.randint(0, 4, size=(60,)).long()
y = torch.randint(0, 4, size=(60,)).long()
batch_size = 10

ck_dynamic = CohenKappa(weights=weights, device=available_device)
ck_fixed = CohenKappa(weights=weights, device=available_device, num_classes=4)

for ck in (ck_dynamic, ck_fixed):
ck.reset()
for i in range(60 // batch_size):
idx = i * batch_size
ck.update((y_pred[idx : idx + batch_size], y[idx : idx + batch_size]))

assert ck_dynamic.compute() == pytest.approx(ck_fixed.compute())


@pytest.mark.parametrize("weights", [None, "linear", "quadratic"])
def test_num_classes_single_batch(weights, available_device):
torch.manual_seed(0)
y_pred = torch.randint(0, 3, size=(30,)).long()
y = torch.randint(0, 3, size=(30,)).long()

ck = CohenKappa(weights=weights, device=available_device, num_classes=3)
ck.reset()
ck.update((y_pred, y))
res = ck.compute()

assert isinstance(res, float)
assert cohen_kappa_score(y.numpy(), y_pred.numpy(), weights=weights) == pytest.approx(res)


def test_num_classes_multilabel_inputs():
ck = CohenKappa(num_classes=4)
with pytest.raises(ValueError, match=r"multilabel-indicator is not supported"):
ck.reset()
ck.update((torch.randint(0, 2, size=(10, 4)).long(), torch.randint(0, 2, size=(10, 4)).long()))
ck.compute()


def test_num_classes_squeeze_n1():
torch.manual_seed(7)
y_pred = torch.randint(0, 2, size=(20, 1)).long()
y = torch.randint(0, 2, size=(20, 1)).long()

ck = CohenKappa(num_classes=2)
ck.reset()
ck.update((y_pred, y))
res = ck.compute()

assert isinstance(res, float)
assert cohen_kappa_score(y.squeeze().numpy(), y_pred.squeeze().numpy()) == pytest.approx(res)
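
A minimal usage sketch (not part of this PR) of the two constructor paths exercised by the tests above; the ``eval_step`` function and the random data are hypothetical placeholders:

import torch

from ignite.engine import Engine
from ignite.metrics import CohenKappa


def eval_step(engine, batch):
    # hypothetical step: each batch already holds predicted class indices and targets
    y_pred, y = batch
    return y_pred, y


evaluator = Engine(eval_step)

# memory-efficient path: a running confusion matrix is accumulated, no raw tensors are buffered
CohenKappa(weights="quadratic", num_classes=4).attach(evaluator, "qwk_fixed")

# backward-compatible path: predictions are buffered and the class count is inferred at compute time
CohenKappa(weights="quadratic").attach(evaluator, "qwk_dynamic")

data = [(torch.randint(0, 4, (16,)), torch.randint(0, 4, (16,))) for _ in range(5)]
state = evaluator.run(data)
print(state.metrics["qwk_fixed"], state.metrics["qwk_dynamic"])

On the same data the two entries should agree up to floating point; the fixed-``num_classes`` variant only differs in how state is stored between batches.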