Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions ignite/metrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from ignite.metrics.gan.fid import FID
from ignite.metrics.gan.inception_score import InceptionScore
from ignite.metrics.gpu_info import GpuInfo
from ignite.metrics.harmonic_mean import HarmonicMean
from ignite.metrics.hsic import HSIC
from ignite.metrics.js_divergence import JSDivergence
from ignite.metrics.kl_divergence import KLDivergence
Expand Down Expand Up @@ -75,6 +76,7 @@
"JaccardIndex",
"JSDivergence",
"KLDivergence",
"HarmonicMean",
"HSIC",
"MaximumMeanDiscrepancy",
"MultiLabelConfusionMatrix",
Expand Down
55 changes: 55 additions & 0 deletions ignite/metrics/harmonic_mean.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
from typing import Callable, Union

import torch

from ignite.exceptions import NotComputableError
from ignite.metrics.metric import Metric, reinit__is_reduced, sync_all_reduce

class HarmonicMean(Metric):
    r"""Calculates the harmonic mean of all values seen so far.

    .. math::
        H = \frac{n}{\sum_{i=1}^n \frac{1}{x_i}}

    where :math:`x_i` are the individual values and :math:`n` is the total count
    of values.

    ``update`` must receive a tensor (or anything accepted by
    :func:`torch.as_tensor`) of strictly positive values; the tensor is
    flattened, so any shape is accepted.

    Args:
        output_transform: A callable that transforms the engine's output into the
            expected format.
        device: Specifies which device updates are accumulated on.

    Example:
        .. code-block:: python

            metric = HarmonicMean()
            metric.attach(evaluator, "harmonic_mean")

    .. versionadded:: 0.5.4
    """

    def __init__(
        self,
        output_transform: Callable = lambda x: x,
        device: Union[str, torch.device] = torch.device("cpu"),
    ):
        # Metric.__init__ already invokes self.reset(); calling it again here
        # would be redundant.
        super(HarmonicMean, self).__init__(output_transform=output_transform, device=device)

    @reinit__is_reduced
    def reset(self) -> None:
        """Reset the accumulated reciprocal sum and example count."""
        super(HarmonicMean, self).reset()
        self._sum_reciprocal = torch.tensor(0.0, device=self._device)
        self._num_examples = 0

    @reinit__is_reduced
    def update(self, output: torch.Tensor) -> None:
        """Accumulate the reciprocals of a batch of positive values.

        Args:
            output: tensor-like of strictly positive values; flattened before
                accumulation.

        Raises:
            ValueError: if any value is zero or negative (the harmonic mean is
                undefined there).
        """
        if not isinstance(output, torch.Tensor):
            output = torch.as_tensor(output)

        # Detach from any autograd graph and move onto the accumulation device.
        values = output.detach().reshape(-1).to(self._device)

        if torch.any(values <= 0):
            raise ValueError("Harmonic mean is only defined for positive values.")

        self._sum_reciprocal += torch.sum(1.0 / values)
        self._num_examples += values.numel()

    @sync_all_reduce("_sum_reciprocal", "_num_examples")
    def compute(self) -> float:
        """Return the harmonic mean of all accumulated values.

        Raises:
            NotComputableError: if ``update`` was never called since the last
                reset.
        """
        if self._num_examples == 0:
            raise NotComputableError("HarmonicMean must have at least one example.")

        return (self._num_examples / self._sum_reciprocal).item()
56 changes: 56 additions & 0 deletions tests/ignite/metrics/test_harmonic_mean.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
import torch
import pytest
from scipy.stats import hmean
from ignite.metrics import HarmonicMean
from ignite.exceptions import NotComputableError

def test_harmonic_mean_basic():
    """A single batch must match scipy.stats.hmean exactly."""
    dev = "cuda" if torch.cuda.is_available() else "cpu"
    hm = HarmonicMean(device=dev)

    values = [1.0, 2.0, 4.0]
    hm.update(torch.tensor(values, device=dev))

    assert hm.compute() == pytest.approx(hmean(values))

def test_harmonic_mean_multiple_updates():
    """Accumulating over several update() calls equals one global hmean."""
    hm = HarmonicMean()

    batches = [[1.0, 10.0], [5.0, 2.0]]
    for batch in batches:
        hm.update(torch.tensor(batch))

    flattened = [v for batch in batches for v in batch]
    assert hm.compute() == pytest.approx(hmean(flattened))

def test_harmonic_mean_invalid_input():
    """Zero and negative entries must be rejected with a ValueError."""
    hm = HarmonicMean()

    bad_values = torch.tensor([1.0, 0.0, -2.0])
    with pytest.raises(ValueError, match="Harmonic mean is only defined for positive values."):
        hm.update(bad_values)

def test_not_computable():
    """compute() before any update() must raise NotComputableError."""
    with pytest.raises(NotComputableError):
        HarmonicMean().compute()

def test_reset():
    """reset() must clear accumulated state, making compute() raise again."""
    hm = HarmonicMean()
    hm.update(torch.tensor([1.0, 2.0]))
    hm.reset()

    with pytest.raises(NotComputableError):
        hm.compute()

def test_harmonic_mean_tensor_shape():
    """Multi-dimensional input is flattened before accumulation."""
    hm = HarmonicMean()

    flat = [1.0, 2.0, 4.0, 8.0]
    hm.update(torch.tensor(flat).reshape(2, 2))

    assert hm.compute() == pytest.approx(hmean(flat))