diff --git a/ignite/metrics/regression/spearman_correlation.py b/ignite/metrics/regression/spearman_correlation.py
index 755698b6f867..0da00d316a40 100644
--- a/ignite/metrics/regression/spearman_correlation.py
+++ b/ignite/metrics/regression/spearman_correlation.py
@@ -10,13 +10,66 @@
 from ignite.metrics.regression._base import _check_output_shapes, _check_output_types
 
 
+def _get_ranks(x: Tensor) -> Tensor:
+    """Return the 1-based ranks of a 1-D tensor, assigning tied values their average rank.
+
+    Args:
+        x: 1-D tensor of values to rank.
+
+    Returns:
+        A tensor of the same length as ``x`` on ``x.device``, in ``float64`` (``float32`` on MPS).
+    """
+    # The MPS backend cannot hold float64 tensors, so ranks fall back to float32 there.
+    rank_dtype = torch.float32 if x.device.type == "mps" else torch.float64
+    if x.numel() == 0:
+        return torch.empty(0, dtype=rank_dtype, device=x.device)
+
+    order = torch.argsort(x)
+    # Group runs of equal sorted values: ``group`` maps each sorted position to its run and
+    # ``counts`` holds the run lengths.
+    _, group, counts = torch.unique_consecutive(x[order], return_inverse=True, return_counts=True)
+
+    # A run covering the 1-based sorted positions first..last gets the rank (first + last) / 2.
+    last = torch.cumsum(counts, dim=0)
+    first = last - counts + 1
+    avg_ranks = (first + last).to(rank_dtype) / 2.0
+
+    # Scatter the ranks back to the original element order.
+    ranks = torch.empty(x.size(0), dtype=rank_dtype, device=x.device)
+    ranks[order] = avg_ranks[group]
+    return ranks
+
+
 def _spearman_r(predictions: Tensor, targets: Tensor) -> float:
-    from scipy.stats import spearmanr
-
-    np_preds = predictions.flatten().cpu().numpy()
-    np_targets = targets.flatten().cpu().numpy()
-    r = spearmanr(np_preds, np_targets).statistic
-    return r
+    """Compute Spearman's rank correlation between ``predictions`` and ``targets``.
+
+    Args:
+        predictions: predicted values; flattened before ranking.
+        targets: ground-truth values; flattened before ranking.
+
+    Returns:
+        The correlation coefficient, or ``nan`` when an input contains NaN, holds fewer than
+        two values, or is constant.
+    """
+    preds_flat = predictions.flatten()
+    targets_flat = targets.flatten()
+
+    if torch.isnan(preds_flat).any() or torch.isnan(targets_flat).any():
+        return float("nan")
+
+    r_preds = _get_ranks(preds_flat)
+    r_targets = _get_ranks(targets_flat)
+
+    # Without rank variance the coefficient is undefined. Comparing the extreme ranks is exact,
+    # whereas a norm-equals-zero test can be defeated by rounding in the mean.
+    if r_preds.numel() < 2 or r_preds.min() == r_preds.max() or r_targets.min() == r_targets.max():
+        return float("nan")
+
+    # Spearman's coefficient is the Pearson correlation of the ranks.
+    diff_x = r_preds - r_preds.mean()
+    diff_y = r_targets - r_targets.mean()
+    r = torch.sum(diff_x * diff_y) / (torch.norm(diff_x, 2) * torch.norm(diff_y, 2))
+    return r.item()
 
 
 class SpearmanRankCorrelation(EpochMetric):
@@ -30,8 +83,7 @@ class SpearmanRankCorrelation(EpochMetric):
     where :math:`A` and :math:`P` are the ground truth and predicted value,
     and :math:`R[X]` is the ranking value of :math:`X`.
 
-    The computation of this metric is implemented with
-    `scipy.stats.spearmanr <https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.spearmanr.html>`_.
+    The computation of this metric is implemented natively in PyTorch.
 
     - ``update`` must receive output of the form ``(y_pred, y)`` or ``{'y_pred': y_pred, 'y': y}``.
     - `y` and `y_pred` must be of same shape `(N, )` or `(N, 1)`.
@@ -76,6 +128,10 @@ class SpearmanRankCorrelation(EpochMetric):
         0.7142857142857143
 
     .. versionadded:: 0.5.2
+
+    .. versionchanged:: 0.5.5
+        Implementation updated to use a native PyTorch computation for rank calculation and
+        correlation, removing the dependency on SciPy.
     """
 
     def __init__(
@@ -85,11 +141,6 @@ def __init__(
         device: str | torch.device = torch.device("cpu"),
         skip_unrolling: bool = False,
     ) -> None:
-        try:
-            from scipy.stats import spearmanr  # noqa: F401
-        except ImportError:
-            raise ModuleNotFoundError("This module requires scipy to be installed.")
-
         super().__init__(_spearman_r, output_transform, check_compute_fn, device, skip_unrolling)
 
     def update(self, output: tuple[torch.Tensor, torch.Tensor]) -> None:
@@ -99,10 +150,10 @@ def update(self, output: tuple[torch.Tensor, torch.Tensor]) -> None:
         if y.ndim == 1:
             y = y.unsqueeze(1)
 
-        _check_output_shapes(output)
-        _check_output_types(output)
+        _check_output_shapes((y_pred, y))
+        _check_output_types((y_pred, y))
 
-        super().update(output)
+        super().update((y_pred, y))
 
     def compute(self) -> float:
         if len(self._predictions) < 1 or len(self._targets) < 1:
diff --git a/tests/ignite/metrics/regression/test_spearman_correlation.py b/tests/ignite/metrics/regression/test_spearman_correlation.py
index fff222b9af40..1649ec17fb27 100644
--- a/tests/ignite/metrics/regression/test_spearman_correlation.py
+++ b/tests/ignite/metrics/regression/test_spearman_correlation.py
@@ -8,6 +8,7 @@
 from ignite.engine import Engine
 from ignite.exceptions import NotComputableError
 from ignite.metrics.regression import SpearmanRankCorrelation
+from ignite.metrics.regression.spearman_correlation import _get_ranks
 
 
 def test_zero_sample():
@@ -67,8 +68,10 @@ def test_spearman_correlation(available_device):
         all_preds.append(x)
         all_targets.append(ground_truth)
 
-    pred_cat = torch.cat(all_preds).numpy()
-    target_cat = torch.cat(all_targets).numpy()
+    pred_cat = torch.cat(all_preds).cpu().numpy()
+    target_cat = torch.cat(all_targets).cpu().numpy()
+
+    # Convert only for computing the expected value
     expected = spearmanr(pred_cat, target_cat).statistic
 
     assert m.compute() == pytest.approx(expected, rel=1e-4)
@@ -105,7 +108,7 @@ def update_fn(engine: Engine, batch):
     corr = engine.run(data, max_epochs=1).metrics["spearman_corr"]
 
     # Convert only for computing the expected value
-    expected = spearmanr(y_pred.numpy().ravel(), y.numpy().ravel()).statistic
+    expected = spearmanr(y_pred.cpu().numpy().ravel(), y.cpu().numpy().ravel()).statistic
 
     assert pytest.approx(expected, rel=2e-4) == corr
 
@@ -182,3 +185,33 @@ def test_integration(self, n_epochs: int):
         np_ans = spearmanr(np_y_pred, np_y).statistic
 
         assert pytest.approx(np_ans, rel=tol) == res
+
+
+def test_nan_inputs():
+    metric = SpearmanRankCorrelation()
+
+    y_pred = torch.tensor([1.0, float("nan"), 3.0])
+    y = torch.tensor([1.0, 2.0, 3.0])
+
+    metric.update((y_pred, y))
+    assert torch.isnan(torch.tensor(metric.compute()))
+
+
+def test_constant_inputs():
+    metric = SpearmanRankCorrelation()
+
+    y_pred = torch.tensor([5.0, 5.0, 5.0, 5.0])
+    y = torch.tensor([1.0, 2.0, 3.0, 4.0])
+
+    metric.update((y_pred, y))
+    assert torch.isnan(torch.tensor(metric.compute()))
+
+
+def test_average_rank_logic():
+    x = torch.tensor([10.0, 20.0, 20.0, 30.0])
+
+    ranks = _get_ranks(x)
+
+    expected = torch.tensor([1.0, 2.5, 2.5, 4.0], dtype=torch.double)
+
+    assert torch.allclose(ranks, expected)