Skip to content

Commit

Permalink
Fix OOM in pearson metric (#380)
Browse files Browse the repository at this point in the history
* fix tests

* fix merge

* fix flake8

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Jirka <[email protected]>
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
Co-authored-by: Jirka Borovec <[email protected]>
  • Loading branch information
5 people authored Jul 28, 2021
1 parent 7151833 commit e00e3ab
Show file tree
Hide file tree
Showing 4 changed files with 111 additions and 53 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Moved `R2Score` from `regression.r2score` to `regression.r2` ([#371](https://github.com/PyTorchLightning/metrics/pull/371))


- Pearson metrics now only store 6 statistics instead of all predictions and targets ([#380](https://github.com/PyTorchLightning/metrics/pull/380))


### Deprecated

- Rename `r2score` >> `r2_score` and `kldivergence` >> `kl_divergence` in `functional` ([#371](https://github.com/PyTorchLightning/metrics/pull/371))
Expand Down
5 changes: 2 additions & 3 deletions tests/regression/test_pearson.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,15 +53,14 @@ class TestPearsonCorrcoef(MetricTester):
atol = 1e-2

@pytest.mark.parametrize("ddp", [True, False])
@pytest.mark.parametrize("dist_sync_on_step", [True, False])
def test_pearson_corrcoef(self, preds, target, ddp, dist_sync_on_step):
def test_pearson_corrcoef(self, preds, target, ddp):
self.run_class_metric_test(
ddp=ddp,
preds=preds,
target=target,
metric_class=PearsonCorrcoef,
sk_metric=_sk_pearsonr,
dist_sync_on_step=dist_sync_on_step,
dist_sync_on_step=False,
)

def test_pearson_corrcoef_functional(self, preds, target):
Expand Down
81 changes: 47 additions & 34 deletions torchmetrics/functional/regression/pearson.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,56 +22,66 @@
def _pearson_corrcoef_update(
preds: Tensor,
target: Tensor,
) -> Tuple[Tensor, Tensor]:
mean_x: Tensor,
mean_y: Tensor,
var_x: Tensor,
var_y: Tensor,
corr_xy: Tensor,
n_prior: Tensor,
) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]:
"""
Updates and returns variables required to compute Pearson Correlation Coefficient.
Checks for same shape of input tensors.
Args:
preds: Predicted tensor
target: Ground truth tensor
mean_x: current mean estimate of x tensor
mean_y: current mean estimate of y tensor
var_x: current variance estimate of x tensor
var_y: current variance estimate of y tensor
corr_xy: current covariance estimate between x and y tensor
n_prior: current number of observed observations
"""

# Data checking
_check_same_shape(preds, target)
preds = preds.squeeze()
target = target.squeeze()
if preds.ndim > 1 or target.ndim > 1:
raise ValueError('Expected both predictions and target to be 1 dimensional tensors.')

return preds, target


def _pearson_corrcoef_compute(preds: Tensor, target: Tensor, eps: float = 1e-6) -> Tensor:
n_obs = preds.numel()
mx_new = (n_prior * mean_x + preds.mean() * n_obs) / (n_prior + n_obs)
my_new = (n_prior * mean_y + target.mean() * n_obs) / (n_prior + n_obs)
n_prior += n_obs
var_x += ((preds - mx_new) * (preds - mean_x)).sum()
var_y += ((target - my_new) * (target - mean_y)).sum()
corr_xy += ((preds - mx_new) * (target - mean_y)).sum()
mean_x = mx_new
mean_y = my_new

return mean_x, mean_y, var_x, var_y, corr_xy, n_prior


def _pearson_corrcoef_compute(
var_x: Tensor,
var_y: Tensor,
corr_xy: Tensor,
nb: Tensor,
) -> Tensor:
"""
Computes Pearson Correlation Coefficient.
Computes the final pearson correlation based on accumulated statistics
Args:
preds: Predicted tensor
target: Ground truth tensor
eps: Avoids ZeroDivisionError. default: 1e-6
Example:
>>> target = torch.tensor([3, -0.5, 2, 7])
>>> preds = torch.tensor([2.5, 0.0, 2, 8])
>>> preds, target = _pearson_corrcoef_update(preds, target)
>>> _pearson_corrcoef_compute(preds, target)
tensor(0.9849)
"""
var_x: variance estimate of x tensor
var_y: variance estimate of y tensor
corr_xy: covariance estimate between x and y tensor
nb: number of observations
preds_diff = preds - preds.mean()
target_diff = target - target.mean()
cov = (preds_diff * target_diff).mean()
preds_std = torch.sqrt((preds_diff * preds_diff).mean())
target_std = torch.sqrt((target_diff * target_diff).mean())

denom = preds_std * target_std
# prevent division by zero
if denom == 0:
denom += eps

corrcoef = cov / denom
"""
var_x /= (nb - 1)
var_y /= (nb - 1)
corr_xy /= (nb - 1)
corrcoef = (corr_xy / (var_x * var_y).sqrt()).squeeze()
return torch.clamp(corrcoef, -1.0, 1.0)


Expand All @@ -90,5 +100,8 @@ def pearson_corrcoef(preds: Tensor, target: Tensor) -> Tensor:
>>> pearson_corrcoef(preds, target)
tensor(0.9849)
"""
preds, target = _pearson_corrcoef_update(preds, target)
return _pearson_corrcoef_compute(preds, target)
_temp = torch.zeros(1, dtype=preds.dtype, device=preds.device)
mean_x, mean_y, var_x = _temp.clone(), _temp.clone(), _temp.clone()
var_y, corr_xy, nb = _temp.clone(), _temp.clone(), _temp.clone()
_, _, var_x, var_y, corr_xy, nb = _pearson_corrcoef_update(preds, target, mean_x, mean_y, var_x, var_y, corr_xy, nb)
return _pearson_corrcoef_compute(var_x, var_y, corr_xy, nb)
75 changes: 59 additions & 16 deletions torchmetrics/regression/pearson.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,45 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, List, Optional
from typing import Any, List, Optional, Tuple

import torch
from torch import Tensor

from torchmetrics.functional.regression.pearson import _pearson_corrcoef_compute, _pearson_corrcoef_update
from torchmetrics.metric import Metric
from torchmetrics.utilities import rank_zero_warn
from torchmetrics.utilities.data import dim_zero_cat


def _final_aggregation(
means_x: Tensor,
means_y: Tensor,
vars_x: Tensor,
vars_y: Tensor,
corrs_xy: Tensor,
nbs: Tensor,
) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
"""
Aggregate the statistics from multiple devices. Formula taken from here:
https://stackoverflow.com/questions/68395368/estimate-running-correlation-on-multiple-nodes
"""
# assert len(means_x) > 1 and len(means_y) > 1 and len(vars_x) > 1 and len(vars_y) > 1 and len(corrs_xy) > 1
mx1, my1, vx1, vy1, cxy1, n1 = means_x[0], means_y[0], vars_x[0], vars_y[0], corrs_xy[0], nbs[0]
for i in range(1, len(means_x)):
mx2, my2, vx2, vy2, cxy2, n2 = means_x[i], means_y[i], vars_x[i], vars_y[i], corrs_xy[i], nbs[i]

nb = n1 + n2
mean_x = (n1 * mx1 + n2 * mx2) / nb
mean_y = (n1 * my1 + n2 * my2) / nb
var_x = (1 / (n1 + n2 - 1) * ((n1 - 1) * vx1 + (n2 - 1) * vx2 + ((n1 * n2) / (n1 + n2)) * (mx1 - mx2)**2))
var_y = (1 / (n1 + n2 - 1) * ((n1 - 1) * vy1 + (n2 - 1) * vy2 + ((n1 * n2) / (n1 + n2)) * (my1 - my2)**2))

corr1 = n1 * cxy1 + n1 * (mx1 - mean_x) * (my1 - mean_y)
corr2 = n2 * cxy2 + n2 * (mx2 - mean_x) * (my2 - mean_y)
corr_xy = (corr1 + corr2) / (n1 + n2)

mx1, my1, vx1, vy1, cxy1, n1 = mean_x, mean_y, var_x, var_y, corr_xy, nb

return var_x, var_y, corr_xy, nb


class PearsonCorrcoef(Metric):
Expand Down Expand Up @@ -58,6 +88,12 @@ class PearsonCorrcoef(Metric):
"""
preds: List[Tensor]
target: List[Tensor]
mean_x: Tensor
mean_y: Tensor
var_x: Tensor
var_y: Tensor
corr_xy: Tensor
n_total: Tensor

def __init__(
self,
Expand All @@ -71,13 +107,12 @@ def __init__(
process_group=process_group,
)

rank_zero_warn(
'Metric `PearsonCorrcoef` will save all targets and predictions in buffer.'
' For large datasets this may lead to large memory footprint.'
)

self.add_state("preds", default=[], dist_reduce_fx="cat")
self.add_state("target", default=[], dist_reduce_fx="cat")
self.add_state("mean_x", default=torch.zeros(1), dist_reduce_fx=None)
self.add_state("mean_y", default=torch.zeros(1), dist_reduce_fx=None)
self.add_state("var_x", default=torch.zeros(1), dist_reduce_fx=None)
self.add_state("var_y", default=torch.zeros(1), dist_reduce_fx=None)
self.add_state("corr_xy", default=torch.zeros(1), dist_reduce_fx=None)
self.add_state("n_total", default=torch.zeros(1), dist_reduce_fx=None)

def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore
"""
Expand All @@ -87,17 +122,25 @@ def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore
preds: Predictions from model
target: Ground truth values
"""
preds, target = _pearson_corrcoef_update(preds, target)
self.preds.append(preds)
self.target.append(target)
self.mean_x, self.mean_y, self.var_x, self.var_y, self.corr_xy, self.n_total = _pearson_corrcoef_update(
preds, target, self.mean_x, self.mean_y, self.var_x, self.var_y, self.corr_xy, self.n_total
)

def compute(self) -> Tensor:
"""
Computes pearson correlation coefficient over state.
"""
preds = dim_zero_cat(self.preds)
target = dim_zero_cat(self.target)
return _pearson_corrcoef_compute(preds, target)
if self.mean_x.numel() > 1: # multiple devices, need further reduction
var_x, var_y, corr_xy, n_total = _final_aggregation(
self.mean_x, self.mean_y, self.var_x, self.var_y, self.corr_xy, self.n_total
)
else:
var_x = self.var_x
var_y = self.var_y
corr_xy = self.corr_xy
n_total = self.n_total

return _pearson_corrcoef_compute(var_x, var_y, corr_xy, n_total)

@property
def is_differentiable(self) -> bool:
Expand Down

0 comments on commit e00e3ab

Please sign in to comment.