diff --git a/src/torchmetrics/functional/segmentation/utils.py b/src/torchmetrics/functional/segmentation/utils.py index 6610cdc8c19..5b6f6eb4418 100644 --- a/src/torchmetrics/functional/segmentation/utils.py +++ b/src/torchmetrics/functional/segmentation/utils.py @@ -337,7 +337,7 @@ def distance_transform( dis_col = (j1.view(-1, 1) - j0.view(1, -1)).abs() # # calculate distance - h, _ = x.shape + h, w = x.shape if metric == "euclidean": dis = ((sampling[0] * dis_row) ** 2 + (sampling[1] * dis_col) ** 2).sqrt() if metric == "chessboard": @@ -348,7 +348,7 @@ def distance_transform( # select only the closest distance mindis, _ = torch.min(dis, dim=1) z = torch.zeros_like(x).view(-1) - z[i1 * h + j1] = mindis + z[i1 * w + j1] = mindis return z.view(x.shape) if not _SCIPY_AVAILABLE: diff --git a/tests/unittests/segmentation/test_hausdorff_distance.py b/tests/unittests/segmentation/test_hausdorff_distance.py index 02db67c7cb6..7491995a788 100644 --- a/tests/unittests/segmentation/test_hausdorff_distance.py +++ b/tests/unittests/segmentation/test_hausdorff_distance.py @@ -141,3 +141,14 @@ def test_hausdorff_distance_functional(self, inputs, input_format, distance_metr def test_hausdorff_distance_raises_error(): """Check that metric raises appropriate errors.""" preds, target = _inputs1 + + +def test_hausdorff_distance_non_square_input(): + """Check that functional Hausdorff distance works for height > width inputs.""" + preds = torch.randint(0, 2, (1, 1, 11, 10)) + target = torch.randint(0, 2, (1, 1, 11, 10)) + + result = hausdorff_distance(preds, target, num_classes=1) + reference = monai_hausdorff_distance(preds, target) + + assert torch.allclose(result, reference.to(result)) diff --git a/tests/unittests/segmentation/test_utils.py b/tests/unittests/segmentation/test_utils.py index 30a6da7d954..0ee7babb966 100644 --- a/tests/unittests/segmentation/test_utils.py +++ b/tests/unittests/segmentation/test_utils.py @@ -166,6 +166,20 @@ def test_distance_transform(case, metric, device): assert torch.allclose(distance.cpu(), torch.from_numpy(scidistance).to(distance.dtype)) +@pytest.mark.parametrize("metric", ["euclidean", "chessboard", "taxicab"]) +def test_distance_transform_non_square_input(metric): + """Check that non-square inputs with height > width are handled correctly.""" + case = torch.randint(0, 2, (11, 10)) + + distance = distance_transform(case, metric=metric) + if metric == "euclidean": + scidistance = scidistance_transform_edt(case) + else: + scidistance = scidistance_transform_cdt(case, metric=metric) + + assert torch.allclose(distance, torch.from_numpy(scidistance).to(distance.dtype)) + + @pytest.mark.parametrize("dim", [2, 3]) @pytest.mark.parametrize("spacing", [1, 2]) def test_neighbour_table(dim, spacing):