lightweight explained variance (#68)
* light explained variance

* fix

* fix

* update CHANGELOG.md

* Update CHANGELOG.md

Co-authored-by: thomasgaudelet <[email protected]>
Co-authored-by: Nicki Skafte <[email protected]>
3 people authored Mar 15, 2021
1 parent 179117b commit 8002ddc
Showing 3 changed files with 53 additions and 34 deletions.
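Note (not part of the commit): the refactor relies on the identity Var(x) = E[x^2] - E[x]^2, so both the residual variance and the target variance can be recovered from running sums instead of buffered tensors. A minimal standalone sketch of that identity, using arbitrary example values:

```python
import torch

# Illustrative only: the five running statistics are enough to compute
# explained variance, because Var(x) = E[x^2] - E[x]^2.
preds = torch.tensor([2.5, 0.0, 2.0, 8.0])
target = torch.tensor([3.0, -0.5, 2.0, 7.0])

n_obs = preds.size(0)
sum_error = torch.sum(target - preds)
sum_squared_error = torch.sum((target - preds) ** 2)
sum_target = torch.sum(target)
sum_squared_target = torch.sum(target ** 2)

# Var(target - preds) and Var(target) from the sums alone
numerator = sum_squared_error / n_obs - (sum_error / n_obs) ** 2
denominator = sum_squared_target / n_obs - (sum_target / n_obs) ** 2
print(1 - numerator / denominator)  # tensor(0.9572)
```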
CHANGELOG.md (2 additions, 0 deletions)
@@ -14,6 +14,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Changed

+- Changed `ExplainedVariance` from storing all preds/targets to tracking 5 statistics ([#68](https://github.com/PyTorchLightning/metrics/pull/68))


### Deprecated

torchmetrics/functional/regression/explained_variance.py (29 additions, 16 deletions)
@@ -18,45 +18,58 @@
from torchmetrics.utilities.checks import _check_same_shape


-def _explained_variance_update(preds: torch.Tensor, target: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+def _explained_variance_update(
+    preds: torch.Tensor, target: torch.Tensor
+) -> Tuple[int, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    _check_same_shape(preds, target)
-    return preds, target
+
+    n_obs = preds.size(0)
+    sum_error = torch.sum(target - preds, dim=0)
+    sum_squared_error = torch.sum((target - preds) ** 2, dim=0)
+
+    sum_target = torch.sum(target, dim=0)
+    sum_squared_target = torch.sum(target ** 2, dim=0)
+
+    return n_obs, sum_error, sum_squared_error, sum_target, sum_squared_target


def _explained_variance_compute(
-    preds: torch.Tensor,
-    target: torch.Tensor,
-    multioutput: str = 'uniform_average',
+    n_obs: torch.Tensor,
+    sum_error: torch.Tensor,
+    sum_squared_error: torch.Tensor,
+    sum_target: torch.Tensor,
+    sum_squared_target: torch.Tensor,
+    multioutput: str = "uniform_average",
) -> Union[torch.Tensor, Sequence[torch.Tensor]]:
-    diff_avg = torch.mean(target - preds, dim=0)
-    numerator = torch.mean((target - preds - diff_avg)**2, dim=0)
+    diff_avg = sum_error / n_obs
+    numerator = sum_squared_error / n_obs - diff_avg ** 2

-    target_avg = torch.mean(target, dim=0)
-    denominator = torch.mean((target - target_avg)**2, dim=0)
+    target_avg = sum_target / n_obs
+    denominator = sum_squared_target / n_obs - target_avg ** 2

# Take care of division by zero
nonzero_numerator = numerator != 0
nonzero_denominator = denominator != 0
valid_score = nonzero_numerator & nonzero_denominator
output_scores = torch.ones_like(diff_avg)
output_scores[valid_score] = 1.0 - (numerator[valid_score] / denominator[valid_score])
-    output_scores[nonzero_numerator & ~nonzero_denominator] = 0.
+    output_scores[nonzero_numerator & ~nonzero_denominator] = 0.0

# Decide what to do in multioutput case
# Todo: allow user to pass in tensor with weights
-    if multioutput == 'raw_values':
+    if multioutput == "raw_values":
        return output_scores
-    if multioutput == 'uniform_average':
+    if multioutput == "uniform_average":
        return torch.mean(output_scores)
-    if multioutput == 'variance_weighted':
+    if multioutput == "variance_weighted":
denom_sum = torch.sum(denominator)
return torch.sum(denominator / denom_sum * output_scores)


def explained_variance(
preds: torch.Tensor,
target: torch.Tensor,
-    multioutput: str = 'uniform_average',
+    multioutput: str = "uniform_average",
) -> Union[torch.Tensor, Sequence[torch.Tensor]]:
"""
Computes explained variance.
@@ -84,5 +97,5 @@ def explained_variance(
>>> explained_variance(preds, target, multioutput='raw_values')
tensor([0.9677, 1.0000])
"""
-    preds, target = _explained_variance_update(preds, target)
-    return _explained_variance_compute(preds, target, multioutput)
+    n_obs, sum_error, sum_squared_error, sum_target, sum_squared_target = _explained_variance_update(preds, target)
+    return _explained_variance_compute(n_obs, sum_error, sum_squared_error, sum_target, sum_squared_target, multioutput)
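Note (not part of the diff): a small self-contained check that the sum-based numerator matches the mean-based formulation the old code used; shapes and values are arbitrary.

```python
import torch

# Sanity sketch: the streaming statistics reproduce the old mean-based numerator.
torch.manual_seed(0)
preds = torch.randn(100, 3)
target = torch.randn(100, 3)

# Old path: work directly on the full tensors
diff_avg_old = torch.mean(target - preds, dim=0)
numerator_old = torch.mean((target - preds - diff_avg_old) ** 2, dim=0)

# New path: only sums are kept
n_obs = preds.size(0)
sum_error = torch.sum(target - preds, dim=0)
sum_squared_error = torch.sum((target - preds) ** 2, dim=0)
numerator_new = sum_squared_error / n_obs - (sum_error / n_obs) ** 2

print(torch.allclose(numerator_old, numerator_new, atol=1e-6))  # True
```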
torchmetrics/regression/explained_variance.py (22 additions, 18 deletions)
@@ -20,7 +20,6 @@
_explained_variance_update,
)
from torchmetrics.metric import Metric
-from torchmetrics.utilities import rank_zero_warn


class ExplainedVariance(Metric):
@@ -77,7 +76,7 @@ class ExplainedVariance(Metric):

def __init__(
self,
-        multioutput: str = 'uniform_average',
+        multioutput: str = "uniform_average",
compute_on_step: bool = True,
dist_sync_on_step: bool = False,
process_group: Optional[Any] = None,
@@ -89,20 +88,17 @@ def __init__(
process_group=process_group,
dist_sync_fn=dist_sync_fn,
)
-        allowed_multioutput = ('raw_values', 'uniform_average', 'variance_weighted')
+        allowed_multioutput = ("raw_values", "uniform_average", "variance_weighted")
        if multioutput not in allowed_multioutput:
            raise ValueError(
-                f'Invalid input to argument `multioutput`. Choose one of the following: {allowed_multioutput}'
+                f"Invalid input to argument `multioutput`. Choose one of the following: {allowed_multioutput}"
            )
        self.multioutput = multioutput
-        self.add_state("y", default=[], dist_reduce_fx=None)
-        self.add_state("y_pred", default=[], dist_reduce_fx=None)
-
-        rank_zero_warn(
-            'Metric `ExplainedVariance` will save all targets and'
-            ' predictions in buffer. For large datasets this may lead'
-            ' to large memory footprint.'
-        )
+        self.add_state("sum_error", default=torch.tensor(0.0), dist_reduce_fx="sum")
+        self.add_state("sum_squared_error", default=torch.tensor(0.0), dist_reduce_fx="sum")
+        self.add_state("sum_target", default=torch.tensor(0.0), dist_reduce_fx="sum")
+        self.add_state("sum_squared_target", default=torch.tensor(0.0), dist_reduce_fx="sum")
+        self.add_state("n_obs", default=torch.tensor(0.0), dist_reduce_fx="sum")

def update(self, preds: torch.Tensor, target: torch.Tensor):
"""
@@ -112,14 +108,22 @@ def update(self, preds: torch.Tensor, target: torch.Tensor):
preds: Predictions from model
target: Ground truth values
"""
-        preds, target = _explained_variance_update(preds, target)
-        self.y_pred.append(preds)
-        self.y.append(target)
+        n_obs, sum_error, sum_squared_error, sum_target, sum_squared_target = _explained_variance_update(preds, target)
+        self.n_obs = self.n_obs + n_obs
+        self.sum_error = self.sum_error + sum_error
+        self.sum_squared_error = self.sum_squared_error + sum_squared_error
+        self.sum_target = self.sum_target + sum_target
+        self.sum_squared_target = self.sum_squared_target + sum_squared_target

def compute(self):
"""
Computes explained variance over state.
"""
-        preds = torch.cat(self.y_pred, dim=0)
-        target = torch.cat(self.y, dim=0)
-        return _explained_variance_compute(preds, target, self.multioutput)
+        return _explained_variance_compute(
+            self.n_obs,
+            self.sum_error,
+            self.sum_squared_error,
+            self.sum_target,
+            self.sum_squared_target,
+            self.multioutput,
+        )

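Note (not part of the commit): a usage sketch of the module-based metric after this change, assuming a torchmetrics version that includes it. The accumulated state is now a fixed set of running sums (one value per output) rather than a growing list of batches, so memory stays constant across updates.

```python
import torch
from torchmetrics import ExplainedVariance

# Synthetic data: targets are predictions plus small noise,
# so the explained variance should be close to 1.0.
metric = ExplainedVariance(multioutput="uniform_average")
for _ in range(10):
    preds = torch.randn(32, 4)
    target = preds + 0.1 * torch.randn(32, 4)
    metric.update(preds, target)

print(metric.compute())  # a single tensor close to 1.0
metric.reset()
```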