From e00e3ab268acee7f7c1c89e67abffe07223d0cbb Mon Sep 17 00:00:00 2001
From: Nicki Skafte
Date: Wed, 28 Jul 2021 18:07:03 +0200
Subject: [PATCH] Fix OOM in pearson metric (#380)

* fix tests

* fix merge

* fix flake8

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Jirka
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
Co-authored-by: Jirka Borovec
---
 CHANGELOG.md                                  |  3 +
 tests/regression/test_pearson.py              |  5 +-
 torchmetrics/functional/regression/pearson.py | 81 +++++++++++--------
 torchmetrics/regression/pearson.py            | 75 +++++++++++++----
 4 files changed, 111 insertions(+), 53 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index c022003bda7..d7ebbd263b8 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -32,6 +32,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - Moved `R2Score` from `regression.r2score` to `regression.r2` ([#371](https://github.com/PyTorchLightning/metrics/pull/371))
 
+- Pearson metrics now only store 6 statistics instead of all predictions and targets ([#380](https://github.com/PyTorchLightning/metrics/pull/380))
+
+
 ### Deprecated
 
 - Rename `r2score` >> `r2_score` and `kldivergence` >> `kl_divergence` in `functional` ([#371](https://github.com/PyTorchLightning/metrics/pull/371))
 
diff --git a/tests/regression/test_pearson.py b/tests/regression/test_pearson.py
index 6c17ab0293e..ecc744e119c 100644
--- a/tests/regression/test_pearson.py
+++ b/tests/regression/test_pearson.py
@@ -53,15 +53,14 @@ class TestPearsonCorrcoef(MetricTester):
     atol = 1e-2
 
     @pytest.mark.parametrize("ddp", [True, False])
-    @pytest.mark.parametrize("dist_sync_on_step", [True, False])
-    def test_pearson_corrcoef(self, preds, target, ddp, dist_sync_on_step):
+    def test_pearson_corrcoef(self, preds, target, ddp):
         self.run_class_metric_test(
             ddp=ddp,
             preds=preds,
             target=target,
             metric_class=PearsonCorrcoef,
             sk_metric=_sk_pearsonr,
-            dist_sync_on_step=dist_sync_on_step,
+            dist_sync_on_step=False,
         )
 
     def test_pearson_corrcoef_functional(self, preds, target):
diff --git a/torchmetrics/functional/regression/pearson.py b/torchmetrics/functional/regression/pearson.py
index 0ff9e722859..3f48abd1a43 100644
--- a/torchmetrics/functional/regression/pearson.py
+++ b/torchmetrics/functional/regression/pearson.py
@@ -22,16 +22,25 @@
 def _pearson_corrcoef_update(
     preds: Tensor,
     target: Tensor,
-) -> Tuple[Tensor, Tensor]:
+    mean_x: Tensor,
+    mean_y: Tensor,
+    var_x: Tensor,
+    var_y: Tensor,
+    corr_xy: Tensor,
+    n_prior: Tensor,
+) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]:
     """
     Updates and returns variables required to compute Pearson Correlation Coefficient.
     Checks for same shape of input tensors.
 
     Args:
         preds: Predicted tensor
         target: Ground truth tensor
+        mean_x: current mean estimate of x tensor
+        mean_y: current mean estimate of y tensor
+        var_x: current variance estimate of x tensor
+        var_y: current variance estimate of y tensor
+        corr_xy: current covariance estimate between x and y tensor
+        n_prior: number of observations seen so far
     """
-    # Data checking
     _check_same_shape(preds, target)
     preds = preds.squeeze()
     target = target.squeeze()
     if preds.ndim > 1 or target.ndim > 1:
         raise ValueError('Expected both predictions and target to be 1 dimensional tensors.')
 
-    return preds, target
-
-
-def _pearson_corrcoef_compute(preds: Tensor, target: Tensor, eps: float = 1e-6) -> Tensor:
+    n_obs = preds.numel()
+    mx_new = (n_prior * mean_x + preds.mean() * n_obs) / (n_prior + n_obs)
+    my_new = (n_prior * mean_y + target.mean() * n_obs) / (n_prior + n_obs)
+    n_prior += n_obs
+    var_x += ((preds - mx_new) * (preds - mean_x)).sum()
+    var_y += ((target - my_new) * (target - mean_y)).sum()
+    corr_xy += ((preds - mx_new) * (target - mean_y)).sum()
+    mean_x = mx_new
+    mean_y = my_new
+
+    return mean_x, mean_y, var_x, var_y, corr_xy, n_prior
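
Note: the hunk above replaces the old "store everything" approach with a running update of six
statistics. As a quick sanity sketch of that arithmetic in plain torch (not part of the patch;
`running_update` is a hypothetical stand-in that mirrors `_pearson_corrcoef_update`):

    import torch

    def running_update(x, y, mx, my, vx, vy, cxy, n):
        # Same update rule as _pearson_corrcoef_update above.
        n_obs = x.numel()
        mx_new = (n * mx + x.mean() * n_obs) / (n + n_obs)
        my_new = (n * my + y.mean() * n_obs) / (n + n_obs)
        n = n + n_obs
        vx = vx + ((x - mx_new) * (x - mx)).sum()    # running sum of squared deviations of x
        vy = vy + ((y - my_new) * (y - my)).sum()    # running sum of squared deviations of y
        cxy = cxy + ((x - mx_new) * (y - my)).sum()  # running sum of cross deviations
        return mx_new, my_new, vx, vy, cxy, n

    x = torch.tensor([2.5, 0.0, 2.0, 8.0])
    y = torch.tensor([3.0, -0.5, 2.0, 7.0])

    state = [torch.zeros(1) for _ in range(6)]
    for i in range(0, 4, 2):  # feed the data in two batches of two samples
        state = running_update(x[i:i + 2], y[i:i + 2], *state)

    mx, my, vx, vy, cxy, n = state
    # The accumulated sums match the one-pass deviations over the full data.
    assert torch.allclose(vx, ((x - x.mean()) ** 2).sum())
    assert torch.allclose(cxy, ((x - x.mean()) * (y - y.mean())).sum())

The states are therefore raw (unnormalized) sums of deviations; normalization is deferred to the
compute step below.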
+
+
+def _pearson_corrcoef_compute(
+    var_x: Tensor,
+    var_y: Tensor,
+    corr_xy: Tensor,
+    nb: Tensor,
+) -> Tensor:
     """
-    Computes Pearson Correlation Coefficient.
+    Computes the final Pearson correlation coefficient based on accumulated statistics.
 
     Args:
-        preds: Predicted tensor
-        target: Ground truth tensor
-        eps: Avoids ZeroDivisionError. default: 1e-6
-
-    Example:
-        >>> target = torch.tensor([3, -0.5, 2, 7])
-        >>> preds = torch.tensor([2.5, 0.0, 2, 8])
-        >>> preds, target = _pearson_corrcoef_update(preds, target)
-        >>> _pearson_corrcoef_compute(preds, target)
-        tensor(0.9849)
-    """
+        var_x: variance estimate of x tensor
+        var_y: variance estimate of y tensor
+        corr_xy: covariance estimate between x and y tensor
+        nb: number of observations
 
-    preds_diff = preds - preds.mean()
-    target_diff = target - target.mean()
-    cov = (preds_diff * target_diff).mean()
-    preds_std = torch.sqrt((preds_diff * preds_diff).mean())
-    target_std = torch.sqrt((target_diff * target_diff).mean())
-
-    denom = preds_std * target_std
-    # prevent division by zero
-    if denom == 0:
-        denom += eps
-
-    corrcoef = cov / denom
+    """
+    var_x /= (nb - 1)
+    var_y /= (nb - 1)
+    corr_xy /= (nb - 1)
+    corrcoef = (corr_xy / (var_x * var_y).sqrt()).squeeze()
     return torch.clamp(corrcoef, -1.0, 1.0)
 
 
@@ -90,5 +100,8 @@ def pearson_corrcoef(preds: Tensor, target: Tensor) -> Tensor:
         >>> pearson_corrcoef(preds, target)
         tensor(0.9849)
     """
-    preds, target = _pearson_corrcoef_update(preds, target)
-    return _pearson_corrcoef_compute(preds, target)
+    _temp = torch.zeros(1, dtype=preds.dtype, device=preds.device)
+    mean_x, mean_y, var_x = _temp.clone(), _temp.clone(), _temp.clone()
+    var_y, corr_xy, nb = _temp.clone(), _temp.clone(), _temp.clone()
+    _, _, var_x, var_y, corr_xy, nb = _pearson_corrcoef_update(preds, target, mean_x, mean_y, var_x, var_y, corr_xy, nb)
+    return _pearson_corrcoef_compute(var_x, var_y, corr_xy, nb)
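
Note that the `(nb - 1)` normalization cancels between `corr_xy` and `(var_x * var_y).sqrt()`; it is
kept so the intermediate values remain interpretable as sample variance and covariance. From a
user's perspective the functional entry point is unchanged, e.g. (illustrative snippet, assuming
torchmetrics at this commit):

    import torch
    from torchmetrics.functional import pearson_corrcoef

    preds = torch.tensor([2.5, 0.0, 2.0, 8.0])
    target = torch.tensor([3.0, -0.5, 2.0, 7.0])
    print(pearson_corrcoef(preds, target))  # tensor(0.9849)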
diff --git a/torchmetrics/regression/pearson.py b/torchmetrics/regression/pearson.py
index 59853843dc1..93ef4df342b 100644
--- a/torchmetrics/regression/pearson.py
+++ b/torchmetrics/regression/pearson.py
@@ -11,15 +11,45 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, List, Optional
+from typing import Any, List, Optional, Tuple
 
 import torch
 from torch import Tensor
 
 from torchmetrics.functional.regression.pearson import _pearson_corrcoef_compute, _pearson_corrcoef_update
 from torchmetrics.metric import Metric
-from torchmetrics.utilities import rank_zero_warn
-from torchmetrics.utilities.data import dim_zero_cat
+
+
+def _final_aggregation(
+    means_x: Tensor,
+    means_y: Tensor,
+    vars_x: Tensor,
+    vars_y: Tensor,
+    corrs_xy: Tensor,
+    nbs: Tensor,
+) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
+    """
+    Aggregate the statistics from multiple devices. The merge formula is taken from:
+    https://stackoverflow.com/questions/68395368/estimate-running-correlation-on-multiple-nodes
+    """
+    mx1, my1, vx1, vy1, cxy1, n1 = means_x[0], means_y[0], vars_x[0], vars_y[0], corrs_xy[0], nbs[0]
+    for i in range(1, len(means_x)):
+        mx2, my2, vx2, vy2, cxy2, n2 = means_x[i], means_y[i], vars_x[i], vars_y[i], corrs_xy[i], nbs[i]
+
+        nb = n1 + n2
+        mean_x = (n1 * mx1 + n2 * mx2) / nb
+        mean_y = (n1 * my1 + n2 * my2) / nb
+        var_x = (1 / (n1 + n2 - 1) * ((n1 - 1) * vx1 + (n2 - 1) * vx2 + ((n1 * n2) / (n1 + n2)) * (mx1 - mx2)**2))
+        var_y = (1 / (n1 + n2 - 1) * ((n1 - 1) * vy1 + (n2 - 1) * vy2 + ((n1 * n2) / (n1 + n2)) * (my1 - my2)**2))
+
+        corr1 = n1 * cxy1 + n1 * (mx1 - mean_x) * (my1 - mean_y)
+        corr2 = n2 * cxy2 + n2 * (mx2 - mean_x) * (my2 - mean_y)
+        corr_xy = (corr1 + corr2) / (n1 + n2)
+
+        mx1, my1, vx1, vy1, cxy1, n1 = mean_x, mean_y, var_x, var_y, corr_xy, nb
+
+    return var_x, var_y, corr_xy, nb
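
The pairwise merge above follows the parallel-variance scheme from the linked Stack Overflow
answer (essentially the update of Chan et al.). For intuition, the core two-partition identity it
relies on can be checked in plain torch (an illustrative sketch, not a call into the function
above, which additionally handles the n - 1 normalization and the covariance term):

    import torch

    x = torch.randn(10)
    a, b = x[:4], x[4:]

    # Per-partition statistics: count, mean, sum of squared deviations.
    na, nb = a.numel(), b.numel()
    ma, mb = a.mean(), b.mean()
    m2a = ((a - ma) ** 2).sum()
    m2b = ((b - mb) ** 2).sum()

    # Parallel merge: the cross term corrects for the differing partition means.
    delta = mb - ma
    m2 = m2a + m2b + delta ** 2 * na * nb / (na + nb)

    assert torch.allclose(m2, ((x - x.mean()) ** 2).sum())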
""" - preds = dim_zero_cat(self.preds) - target = dim_zero_cat(self.target) - return _pearson_corrcoef_compute(preds, target) + if self.mean_x.numel() > 1: # multiple devices, need further reduction + var_x, var_y, corr_xy, n_total = _final_aggregation( + self.mean_x, self.mean_y, self.var_x, self.var_y, self.corr_xy, self.n_total + ) + else: + var_x = self.var_x + var_y = self.var_y + corr_xy = self.corr_xy + n_total = self.n_total + + return _pearson_corrcoef_compute(var_x, var_y, corr_xy, n_total) @property def is_differentiable(self) -> bool: