diff --git a/ignite/metrics/__init__.py b/ignite/metrics/__init__.py index c26660d58fe3..0f80bf282dd3 100644 --- a/ignite/metrics/__init__.py +++ b/ignite/metrics/__init__.py @@ -37,6 +37,7 @@ from ignite.metrics.psnr import PSNR from ignite.metrics.recall import Recall from ignite.metrics.rec_sys.hitrate import HitRate +from ignite.metrics.rec_sys.mrr import MRR from ignite.metrics.roc_auc import ROC_AUC, RocCurve from ignite.metrics.root_mean_squared_error import RootMeanSquaredError from ignite.metrics.running_average import RunningAverage @@ -104,4 +105,5 @@ "CommonObjectDetectionMetrics", "coco_tensor_list_to_dict_list", "HitRate", + "MRR", ] diff --git a/ignite/metrics/rec_sys/__init__.py b/ignite/metrics/rec_sys/__init__.py index f6f37785cb4e..6876625f6d98 100644 --- a/ignite/metrics/rec_sys/__init__.py +++ b/ignite/metrics/rec_sys/__init__.py @@ -1 +1,2 @@ from ignite.metrics.rec_sys.hitrate import HitRate +from ignite.metrics.rec_sys.mrr import MRR diff --git a/ignite/metrics/rec_sys/mrr.py b/ignite/metrics/rec_sys/mrr.py new file mode 100644 index 000000000000..fcc3fed27ad9 --- /dev/null +++ b/ignite/metrics/rec_sys/mrr.py @@ -0,0 +1,167 @@ +from typing import Callable + +import torch + +from ignite.exceptions import NotComputableError +from ignite.metrics.metric import Metric, reinit__is_reduced, sync_all_reduce + +__all__ = ["MRR"] + + +class MRR(Metric): + r"""Calculates the Mean Reciprocal Rank (MRR) at `k` for Recommendation Systems. + + MRR measures the average of the reciprocal of the rank of the first relevant item + in the predicted list. It is widely used in retrieval systems, recommendation systems, + and RAG pipelines. + + .. math:: \text{MRR}@K = \frac{1}{N} \sum_{i=1}^{N} \frac{1}{\text{rank}_i} + + where :math:`\text{rank}_i` is the rank (1-indexed) of the first relevant item + in the top-K predictions for user :math:`i`. If no relevant item is found in the + top-K, the reciprocal rank for that user is 0. 
+ + - ``update`` must receive output of the form ``(y_pred, y)``. + - ``y_pred`` is expected to be raw logits or probability score for each item in the catalog. + - ``y`` is expected to be binary (only 0s and 1s) values where `1` indicates relevant item. + - ``y_pred`` and ``y`` must be of shape :math:`(batch, num\_items)`. + - returns a list of MRR ordered by the sorted values of ``top_k``. + + Args: + top_k: a list of sorted positive integers that specifies `k` for calculating MRR@top-k. + ignore_zero_hits: if True, users with no relevant items (ground truth tensor being all zeros) + are ignored in computation of MRR. If set False, such users are counted as having + reciprocal rank of 0. By default, True. + output_transform: a callable that is used to transform the + :class:`~ignite.engine.engine.Engine`'s ``process_function``'s output into the + form expected by the metric. + The output is expected to be a tuple `(prediction, target)` + where `prediction` and `target` are tensors + of shape ``(batch, num_items)``. + device: specifies which device updates are accumulated on. Setting the + metric's device to be the same as your ``update`` arguments ensures the ``update`` method is + non-blocking. By default, CPU. + skip_unrolling: specifies whether input should be unrolled or not before being + processed. Should be true for multi-output models. + + Examples: + To use with ``Engine`` and ``process_function``, simply attach the metric instance to the engine. + The output of the engine's ``process_function`` needs to be in the format of + ``(y_pred, y)``. If not, ``output_transform`` can be added + to the metric to transform the output into the form expected by the metric. + + For more information on how metric works with :class:`~ignite.engine.engine.Engine`, visit :ref:`attach-engine`. + + .. include:: defaults.rst + :start-after: :orphan: + + ignore_zero_hits=True case + + .. 
testcode:: 1 + + metric = MRR(top_k=[1, 2, 3, 4]) + metric.attach(default_evaluator,"mrr") + y_pred=torch.Tensor([ + [4.0, 2.0, 3.0, 1.0], + [1.0, 2.0, 3.0, 4.0] + ]) + y_true=torch.Tensor([ + [0.0, 0.0, 1.0, 1.0], + [0.0, 0.0, 0.0, 0.0] + ]) + state = default_evaluator.run([(y_pred, y_true)]) + print(state.metrics["mrr"]) + + .. testoutput:: 1 + + [0.0, 0.5, 0.5, 0.5] + + ignore_zero_hits=False case + + .. testcode:: 2 + + metric = MRR(top_k=[1, 2, 3, 4], ignore_zero_hits=False) + metric.attach(default_evaluator,"mrr") + y_pred=torch.Tensor([ + [4.0, 2.0, 3.0, 1.0], + [1.0, 2.0, 3.0, 4.0] + ]) + y_true=torch.Tensor([ + [0.0, 0.0, 1.0, 1.0], + [0.0, 0.0, 0.0, 0.0] + ]) + state = default_evaluator.run([(y_pred, y_true)]) + print(state.metrics["mrr"]) + + .. testoutput:: 2 + + [0.0, 0.25, 0.25, 0.25] + + .. versionadded:: 0.6.0 + """ + + required_output_keys = ("y_pred", "y") + _state_dict_all_req_keys = ("_sum_reciprocal_ranks_per_k", "_num_examples") + + def __init__( + self, + top_k: list[int], + ignore_zero_hits: bool = True, + output_transform: Callable = lambda x: x, + device: str | torch.device = torch.device("cpu"), + skip_unrolling: bool = False, + ): + if any(k <= 0 for k in top_k): + raise ValueError(" top_k must be list of positive integers only.") + + self.top_k = sorted(top_k) + self.ignore_zero_hits = ignore_zero_hits + super(MRR, self).__init__(output_transform, device=device, skip_unrolling=skip_unrolling) + + @reinit__is_reduced + def reset(self) -> None: + self._sum_reciprocal_ranks_per_k = torch.zeros(len(self.top_k), device=self._device) + self._num_examples = 0 + + @reinit__is_reduced + def update(self, output: tuple[torch.Tensor, torch.Tensor]) -> None: + if len(output) != 2: + raise ValueError(f"output should be in format `(y_pred,y)` but got tuple of {len(output)} tensors.") + + y_pred, y = output + if y_pred.shape != y.shape: + raise ValueError(f"y_pred and y must be in the same shape, got {y_pred.shape} != {y.shape}.") + + if 
self.ignore_zero_hits: + valid_mask = torch.any(y > 0, dim=-1) + y_pred = y_pred[valid_mask] + y = y[valid_mask] + + if y.shape[0] == 0: + return + + max_k = self.top_k[-1] + _, indices = torch.topk(y_pred, k=max_k, dim=-1) + + hits_at_max_k = torch.gather(y, dim=-1, index=indices) + + for i, k in enumerate(self.top_k): + hits_at_k = hits_at_max_k[:, :k] + has_hit = torch.any(hits_at_k > 0, dim=-1) + first_hit_pos = torch.argmax((hits_at_k > 0).int(), dim=-1) + reciprocal_rank = torch.where( + has_hit, + 1.0 / (first_hit_pos.float() + 1), + torch.zeros_like(first_hit_pos, dtype=torch.float), + ) + self._sum_reciprocal_ranks_per_k[i] += reciprocal_rank.sum().to(self._device) + + self._num_examples += y.shape[0] + + @sync_all_reduce("_sum_reciprocal_ranks_per_k", "_num_examples") + def compute(self) -> list[float]: + if self._num_examples == 0: + raise NotComputableError("MRR must have at least one example.") + + rates = (self._sum_reciprocal_ranks_per_k / self._num_examples).tolist() + return rates diff --git a/tests/ignite/metrics/rec_sys/test_mrr_metric.py b/tests/ignite/metrics/rec_sys/test_mrr_metric.py new file mode 100644 index 000000000000..dcc8124fb0d3 --- /dev/null +++ b/tests/ignite/metrics/rec_sys/test_mrr_metric.py @@ -0,0 +1,230 @@ +import pytest +import torch +import numpy as np + +import ignite.distributed as idist +from ignite.engine import Engine +from ignite.exceptions import NotComputableError +from ignite.metrics.rec_sys.mrr import MRR + + +def manual_mrr( + y_pred: np.ndarray, + y: np.ndarray, + top_k: list[int], + ignore_zero_hits: bool = True, +) -> list[float]: + """Manual implementation of MRR using numpy for verification.""" + sorted_top_k = sorted(top_k) + + if ignore_zero_hits: + valid_mask = np.any(y > 0, axis=-1) + y_pred = y_pred[valid_mask] + y = y[valid_mask] + + n_samples = y.shape[0] + if n_samples == 0: + raise ValueError("No valid samples for manual MRR computation.") + + sorted_indices = np.argsort(-y_pred, axis=-1) + + results 
= [] + for k in sorted_top_k: + k_indices = sorted_indices[:, :k] + rr_sum = 0.0 + for i in range(n_samples): + relevance = y[i, k_indices[i]] + hits = np.where(relevance > 0)[0] + if len(hits) > 0: + rr_sum += 1.0 / (hits[0] + 1) + results.append(rr_sum / n_samples) + + return results + + +def test_zero_sample(): + metric = MRR(top_k=[1, 5]) + with pytest.raises(NotComputableError, match=r"MRR must have at least one example"): + metric.compute() + + +def test_shape_mismatch(): + metric = MRR(top_k=[1]) + y_pred = torch.randn(4, 10) + y = torch.ones(4, 5) # Mismatched items count + with pytest.raises(ValueError, match="y_pred and y must be in the same shape"): + metric.update((y_pred, y)) + + +def test_invalid_top_k(): + with pytest.raises(ValueError, match="top_k must be list of positive integers"): + MRR(top_k=[0]) + with pytest.raises(ValueError, match="top_k must be list of positive integers"): + MRR(top_k=[-1, 5]) + + +@pytest.mark.parametrize("top_k", [[1], [1, 2, 4]]) +@pytest.mark.parametrize("ignore_zero_hits", [True, False]) +def test_compute(top_k, ignore_zero_hits, available_device): + metric = MRR( + top_k=top_k, + ignore_zero_hits=ignore_zero_hits, + device=available_device, + ) + + y_pred = torch.tensor([[4.0, 2.0, 3.0, 1.0], [1.0, 2.0, 3.0, 4.0]]) + y_true = torch.tensor([[0, 0, 1.0, 1.0], [0, 0, 0.0, 0.0]]) + + metric.update((y_pred, y_true)) + res = metric.compute() + + expected = manual_mrr( + y_pred.numpy(), + y_true.numpy(), + top_k, + ignore_zero_hits=ignore_zero_hits, + ) + + assert isinstance(res, list) + assert len(res) == len(top_k) + np.testing.assert_allclose(res, expected) + + +def test_known_values(): + """Test with manually computed expected values.""" + metric = MRR(top_k=[1, 2, 3, 4]) + + # User 1: y_pred=[4,2,3,1] -> sorted indices [0,2,1,3] + # y=[0,0,1,1] -> relevance at sorted positions: [0,1,0,1] + # MRR@1: no hit -> 0 + # MRR@2: first hit at position 2 -> 1/2 = 0.5 + # MRR@3: first hit at position 2 -> 1/2 = 0.5 + # MRR@4: 
first hit at position 2 -> 1/2 = 0.5 + y_pred = torch.tensor([[4.0, 2.0, 3.0, 1.0]]) + y_true = torch.tensor([[0.0, 0.0, 1.0, 1.0]]) + + metric.update((y_pred, y_true)) + res = metric.compute() + + assert res == pytest.approx([0.0, 0.5, 0.5, 0.5]) + + +def test_perfect_prediction(): + """Test when the most relevant item is the top prediction.""" + metric = MRR(top_k=[1, 3]) + + y_pred = torch.tensor([[5.0, 1.0, 2.0]]) + y_true = torch.tensor([[1.0, 0.0, 0.0]]) + + metric.update((y_pred, y_true)) + res = metric.compute() + + assert res == pytest.approx([1.0, 1.0]) + + +def test_multiple_batches(): + """Test accumulation across multiple update calls.""" + metric = MRR(top_k=[2]) + + # Batch 1: hit at position 2 -> RR = 0.5 + y_pred1 = torch.tensor([[4.0, 2.0, 3.0, 1.0]]) + y_true1 = torch.tensor([[0.0, 0.0, 1.0, 1.0]]) + metric.update((y_pred1, y_true1)) + + # Batch 2: hit at position 1 -> RR = 1.0 + y_pred2 = torch.tensor([[5.0, 1.0, 2.0, 3.0]]) + y_true2 = torch.tensor([[1.0, 0.0, 0.0, 0.0]]) + metric.update((y_pred2, y_true2)) + + res = metric.compute() + # MRR = (0.5 + 1.0) / 2 = 0.75 + assert res == pytest.approx([0.75]) + + +def test_accumulator_detached(available_device): + metric = MRR(top_k=[1], device=available_device) + y_pred = torch.randn(4, 5, requires_grad=True) + y = torch.randint(0, 2, (4, 5)).float() + metric.update((y_pred, y)) + + assert metric._sum_reciprocal_ranks_per_k.requires_grad is False + assert metric._sum_reciprocal_ranks_per_k.is_leaf is True + + +def test_all_zero_targets_ignore(): + metric = MRR(top_k=[1, 3], ignore_zero_hits=True) + + y_pred = torch.randn(4, 5) + y = torch.zeros(4, 5) + + metric.update((y_pred, y)) + + with pytest.raises(NotComputableError): + metric.compute() + + +@pytest.mark.usefixtures("distributed") +class TestDistributed: + def test_integration(self): + n_iters = 10 + batch_size = 4 + num_items = 20 + top_k = [1, 5] + + rank = idist.get_rank() + torch.manual_seed(42 + rank) + device = idist.device() + + 
metric_devices = [torch.device("cpu")] + if device.type != "xla": + metric_devices.append(device) + + for metric_device in metric_devices: + all_y_true = torch.randint(0, 2, (n_iters * batch_size, num_items)).float().to(device) + all_y_pred = torch.randn((n_iters * batch_size, num_items)).to(device) + + for ignore_zero_hits in [True, False]: + engine = Engine( + lambda e, i: ( + all_y_pred[i * batch_size : (i + 1) * batch_size], + all_y_true[i * batch_size : (i + 1) * batch_size], + ) + ) + m = MRR( + top_k=top_k, + ignore_zero_hits=ignore_zero_hits, + device=metric_device, + ) + m.attach(engine, "mrr") + + engine.run(range(n_iters), max_epochs=1) + + global_y_true = idist.all_gather(all_y_true).cpu().numpy() + global_y_pred = idist.all_gather(all_y_pred).cpu().numpy() + + res = engine.state.metrics["mrr"] + + true_res = manual_mrr( + global_y_pred, + global_y_true, + top_k, + ignore_zero_hits=ignore_zero_hits, + ) + + assert isinstance(res, list) + assert res == pytest.approx(true_res) + + engine.state.metrics.clear() + + def test_accumulator_device(self): + device = idist.device() + metric = MRR(top_k=[1, 5], device=device) + + assert metric._device == device + assert metric._sum_reciprocal_ranks_per_k.device == device + + y_pred = torch.randn(2, 10) + y = torch.zeros(2, 10) + metric.update((y_pred, y)) + + assert metric._sum_reciprocal_ranks_per_k.device == device