-
-
Notifications
You must be signed in to change notification settings - Fork 694
fix(metrics): use Welford's algorithm in PearsonCorrelation for numerical stability #3741
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -2,8 +2,9 @@ | |||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| import torch | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| import ignite.distributed as idist | ||||||||||||||||||||||||||||||||||||||||||||||||||
| from ignite.exceptions import NotComputableError | ||||||||||||||||||||||||||||||||||||||||||||||||||
| from ignite.metrics.metric import reinit__is_reduced, sync_all_reduce | ||||||||||||||||||||||||||||||||||||||||||||||||||
| from ignite.metrics.metric import reinit__is_reduced | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| from ignite.metrics.regression._base import _BaseRegression | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -19,6 +20,11 @@ class PearsonCorrelation(_BaseRegression): | |||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| where :math:`A_j` is the ground truth and :math:`P_j` is the predicted value. | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| Internally uses `Welford's online algorithm <https://en.wikipedia.org/wiki/ | ||||||||||||||||||||||||||||||||||||||||||||||||||
| Algorithms_for_calculating_variance#Welford's_online_algorithm>`_ for numerically | ||||||||||||||||||||||||||||||||||||||||||||||||||
| stable computation, avoiding catastrophic cancellation that can occur with the | ||||||||||||||||||||||||||||||||||||||||||||||||||
| naive sum-of-squares formula in float32. | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| - ``update`` must receive output of the form ``(y_pred, y)`` or ``{'y_pred': y_pred, 'y': y}``. | ||||||||||||||||||||||||||||||||||||||||||||||||||
| - `y` and `y_pred` must be of same shape `(N, )` or `(N, 1)`. | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -54,7 +60,10 @@ class PearsonCorrelation(_BaseRegression): | |||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| .. testoutput:: | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| 0.9768688678741455 | ||||||||||||||||||||||||||||||||||||||||||||||||||
| 0.9768687504744322 | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| .. versionchanged:: 0.6.0 | ||||||||||||||||||||||||||||||||||||||||||||||||||
| Uses Welford's online algorithm for improved numerical stability. | ||||||||||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| def __init__( | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -68,56 +77,106 @@ def __init__( | |||||||||||||||||||||||||||||||||||||||||||||||||
| self.eps = eps | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| _state_dict_all_req_keys = ( | ||||||||||||||||||||||||||||||||||||||||||||||||||
| "_sum_of_y_preds", | ||||||||||||||||||||||||||||||||||||||||||||||||||
| "_sum_of_ys", | ||||||||||||||||||||||||||||||||||||||||||||||||||
| "_sum_of_y_pred_squares", | ||||||||||||||||||||||||||||||||||||||||||||||||||
| "_sum_of_y_squares", | ||||||||||||||||||||||||||||||||||||||||||||||||||
| "_sum_of_products", | ||||||||||||||||||||||||||||||||||||||||||||||||||
| "_num_examples", | ||||||||||||||||||||||||||||||||||||||||||||||||||
| "_mean_x", | ||||||||||||||||||||||||||||||||||||||||||||||||||
| "_mean_y", | ||||||||||||||||||||||||||||||||||||||||||||||||||
| "_m2_x", | ||||||||||||||||||||||||||||||||||||||||||||||||||
| "_m2_y", | ||||||||||||||||||||||||||||||||||||||||||||||||||
| "_cxy", | ||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| @reinit__is_reduced | ||||||||||||||||||||||||||||||||||||||||||||||||||
| def reset(self) -> None: | ||||||||||||||||||||||||||||||||||||||||||||||||||
| self._sum_of_y_preds = torch.tensor(0.0, device=self._device) | ||||||||||||||||||||||||||||||||||||||||||||||||||
| self._sum_of_ys = torch.tensor(0.0, device=self._device) | ||||||||||||||||||||||||||||||||||||||||||||||||||
| self._sum_of_y_pred_squares = torch.tensor(0.0, device=self._device) | ||||||||||||||||||||||||||||||||||||||||||||||||||
| self._sum_of_y_squares = torch.tensor(0.0, device=self._device) | ||||||||||||||||||||||||||||||||||||||||||||||||||
| self._sum_of_products = torch.tensor(0.0, device=self._device) | ||||||||||||||||||||||||||||||||||||||||||||||||||
| self._num_examples = 0 | ||||||||||||||||||||||||||||||||||||||||||||||||||
| self._mean_x = torch.tensor(0.0, dtype=torch.float64, device=self._device) | ||||||||||||||||||||||||||||||||||||||||||||||||||
| self._mean_y = torch.tensor(0.0, dtype=torch.float64, device=self._device) | ||||||||||||||||||||||||||||||||||||||||||||||||||
| self._m2_x = torch.tensor(0.0, dtype=torch.float64, device=self._device) | ||||||||||||||||||||||||||||||||||||||||||||||||||
| self._m2_y = torch.tensor(0.0, dtype=torch.float64, device=self._device) | ||||||||||||||||||||||||||||||||||||||||||||||||||
| self._cxy = torch.tensor(0.0, dtype=torch.float64, device=self._device) | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| def _update(self, output: tuple[torch.Tensor, torch.Tensor]) -> None: | ||||||||||||||||||||||||||||||||||||||||||||||||||
| y_pred, y = output[0].detach(), output[1].detach() | ||||||||||||||||||||||||||||||||||||||||||||||||||
| self._sum_of_y_preds += y_pred.sum().to(self._device) | ||||||||||||||||||||||||||||||||||||||||||||||||||
| self._sum_of_ys += y.sum().to(self._device) | ||||||||||||||||||||||||||||||||||||||||||||||||||
| self._sum_of_y_pred_squares += y_pred.square().sum().to(self._device) | ||||||||||||||||||||||||||||||||||||||||||||||||||
| self._sum_of_y_squares += y.square().sum().to(self._device) | ||||||||||||||||||||||||||||||||||||||||||||||||||
| self._sum_of_products += (y_pred * y).sum().to(self._device) | ||||||||||||||||||||||||||||||||||||||||||||||||||
| self._num_examples += y.shape[0] | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| @sync_all_reduce( | ||||||||||||||||||||||||||||||||||||||||||||||||||
| "_sum_of_y_preds", | ||||||||||||||||||||||||||||||||||||||||||||||||||
| "_sum_of_ys", | ||||||||||||||||||||||||||||||||||||||||||||||||||
| "_sum_of_y_pred_squares", | ||||||||||||||||||||||||||||||||||||||||||||||||||
| "_sum_of_y_squares", | ||||||||||||||||||||||||||||||||||||||||||||||||||
| "_sum_of_products", | ||||||||||||||||||||||||||||||||||||||||||||||||||
| "_num_examples", | ||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||
| y_pred = y_pred.to(dtype=torch.float64) | ||||||||||||||||||||||||||||||||||||||||||||||||||
| y = y.to(dtype=torch.float64) | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| n_b = y.shape[0] | ||||||||||||||||||||||||||||||||||||||||||||||||||
| n_a = self._num_examples | ||||||||||||||||||||||||||||||||||||||||||||||||||
| n_ab = n_a + n_b | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| mean_x_b = y_pred.mean().to(self._device) | ||||||||||||||||||||||||||||||||||||||||||||||||||
| mean_y_b = y.mean().to(self._device) | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| # Within-batch second moments | ||||||||||||||||||||||||||||||||||||||||||||||||||
| dx_b = y_pred - mean_x_b | ||||||||||||||||||||||||||||||||||||||||||||||||||
| dy_b = y - mean_y_b | ||||||||||||||||||||||||||||||||||||||||||||||||||
| m2_x_b = dx_b.square().sum().to(self._device) | ||||||||||||||||||||||||||||||||||||||||||||||||||
| m2_y_b = dy_b.square().sum().to(self._device) | ||||||||||||||||||||||||||||||||||||||||||||||||||
| cxy_b = (dx_b * dy_b).sum().to(self._device) | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+121
to
+128
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| mean_x_b = y_pred.mean().to(self._device) | |
| mean_y_b = y.mean().to(self._device) | |
| # Within-batch second moments | |
| dx_b = y_pred - mean_x_b | |
| dy_b = y - mean_y_b | |
| m2_x_b = dx_b.square().sum().to(self._device) | |
| m2_y_b = dy_b.square().sum().to(self._device) | |
| cxy_b = (dx_b * dy_b).sum().to(self._device) | |
| mean_x_b = y_pred.mean() | |
| mean_y_b = y.mean() | |
| # Within-batch second moments | |
| dx_b = y_pred - mean_x_b | |
| dy_b = y - mean_y_b | |
| m2_x_b = dx_b.square().sum() | |
| m2_y_b = dy_b.square().sum() | |
| cxy_b = (dx_b * dy_b).sum() | |
| mean_x_b = mean_x_b.to(self._device) | |
| mean_y_b = mean_y_b.to(self._device) | |
| m2_x_b = m2_x_b.to(self._device) | |
| m2_y_b = m2_y_b.to(self._device) | |
| cxy_b = cxy_b.to(self._device) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The metric state is initialized and updated using
torch.float64tensors unconditionally. This will raise on MPS (float64 is not supported there), andavailable_deviceincludesmpsfor this test suite. Consider selecting an internal accumulator dtype based on device capabilities (e.g., float64 normally, but float32 on MPS, similar to howSSIMhandles this), and use that dtype consistently inreset,_updatecasts, and the distributedstatetensor incompute.