Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

lightweight explained variance #68

Merged
merged 8 commits into from
Mar 15, 2021
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Changed

- Changed `ExplainedVariance` from storing all preds/targets to tracking 5 statistics ([#68](https://github.com/PyTorchLightning/metrics/pull/68))


### Deprecated

Expand Down
45 changes: 29 additions & 16 deletions torchmetrics/functional/regression/explained_variance.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,45 +18,58 @@
from torchmetrics.utilities.checks import _check_same_shape


def _explained_variance_update(preds: torch.Tensor, target: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
def _explained_variance_update(
preds: torch.Tensor, target: torch.Tensor
) -> Tuple[int, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
_check_same_shape(preds, target)
return preds, target

n_obs = preds.size(0)
sum_error = torch.sum(target - preds, dim=0)
sum_squared_error = torch.sum((target - preds) ** 2, dim=0)

sum_target = torch.sum(target, dim=0)
sum_squared_target = torch.sum(target ** 2, dim=0)

return n_obs, sum_error, sum_squared_error, sum_target, sum_squared_target


def _explained_variance_compute(
preds: torch.Tensor,
target: torch.Tensor,
multioutput: str = 'uniform_average',
n_obs: torch.Tensor,
sum_error: torch.Tensor,
sum_squared_error: torch.Tensor,
sum_target: torch.Tensor,
sum_squared_target: torch.Tensor,
multioutput: str = "uniform_average",
) -> Union[torch.Tensor, Sequence[torch.Tensor]]:
diff_avg = torch.mean(target - preds, dim=0)
numerator = torch.mean((target - preds - diff_avg)**2, dim=0)
diff_avg = sum_error / n_obs
numerator = sum_squared_error / n_obs - diff_avg ** 2

target_avg = torch.mean(target, dim=0)
denominator = torch.mean((target - target_avg)**2, dim=0)
target_avg = sum_target / n_obs
denominator = sum_squared_target / n_obs - target_avg ** 2

# Take care of division by zero
nonzero_numerator = numerator != 0
nonzero_denominator = denominator != 0
valid_score = nonzero_numerator & nonzero_denominator
output_scores = torch.ones_like(diff_avg)
output_scores[valid_score] = 1.0 - (numerator[valid_score] / denominator[valid_score])
output_scores[nonzero_numerator & ~nonzero_denominator] = 0.
output_scores[nonzero_numerator & ~nonzero_denominator] = 0.0

# Decide what to do in multioutput case
# Todo: allow user to pass in tensor with weights
if multioutput == 'raw_values':
if multioutput == "raw_values":
return output_scores
if multioutput == 'uniform_average':
if multioutput == "uniform_average":
return torch.mean(output_scores)
if multioutput == 'variance_weighted':
if multioutput == "variance_weighted":
denom_sum = torch.sum(denominator)
return torch.sum(denominator / denom_sum * output_scores)


def explained_variance(
preds: torch.Tensor,
target: torch.Tensor,
multioutput: str = 'uniform_average',
multioutput: str = "uniform_average",
) -> Union[torch.Tensor, Sequence[torch.Tensor]]:
"""
Computes explained variance.
Expand Down Expand Up @@ -84,5 +97,5 @@ def explained_variance(
>>> explained_variance(preds, target, multioutput='raw_values')
tensor([0.9677, 1.0000])
"""
preds, target = _explained_variance_update(preds, target)
return _explained_variance_compute(preds, target, multioutput)
n_obs, sum_error, sum_squared_error, sum_target, sum_squared_target = _explained_variance_update(preds, target)
return _explained_variance_compute(n_obs, sum_error, sum_squared_error, sum_target, sum_squared_target, multioutput)
40 changes: 22 additions & 18 deletions torchmetrics/regression/explained_variance.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
_explained_variance_update,
)
from torchmetrics.metric import Metric
from torchmetrics.utilities import rank_zero_warn


class ExplainedVariance(Metric):
Expand Down Expand Up @@ -77,7 +76,7 @@ class ExplainedVariance(Metric):

def __init__(
self,
multioutput: str = 'uniform_average',
multioutput: str = "uniform_average",
compute_on_step: bool = True,
dist_sync_on_step: bool = False,
process_group: Optional[Any] = None,
Expand All @@ -89,20 +88,17 @@ def __init__(
process_group=process_group,
dist_sync_fn=dist_sync_fn,
)
allowed_multioutput = ('raw_values', 'uniform_average', 'variance_weighted')
allowed_multioutput = ("raw_values", "uniform_average", "variance_weighted")
if multioutput not in allowed_multioutput:
raise ValueError(
f'Invalid input to argument `multioutput`. Choose one of the following: {allowed_multioutput}'
f"Invalid input to argument `multioutput`. Choose one of the following: {allowed_multioutput}"
)
self.multioutput = multioutput
self.add_state("y", default=[], dist_reduce_fx=None)
self.add_state("y_pred", default=[], dist_reduce_fx=None)

rank_zero_warn(
'Metric `ExplainedVariance` will save all targets and'
' predictions in buffer. For large datasets this may lead'
' to large memory footprint.'
)
self.add_state("sum_error", default=torch.tensor(0.0), dist_reduce_fx="sum")
self.add_state("sum_squared_error", default=torch.tensor(0.0), dist_reduce_fx="sum")
self.add_state("sum_target", default=torch.tensor(0.0), dist_reduce_fx="sum")
self.add_state("sum_squared_target", default=torch.tensor(0.0), dist_reduce_fx="sum")
self.add_state("n_obs", default=torch.tensor(0.0), dist_reduce_fx="sum")

def update(self, preds: torch.Tensor, target: torch.Tensor):
"""
Expand All @@ -112,14 +108,22 @@ def update(self, preds: torch.Tensor, target: torch.Tensor):
preds: Predictions from model
target: Ground truth values
"""
preds, target = _explained_variance_update(preds, target)
self.y_pred.append(preds)
self.y.append(target)
n_obs, sum_error, sum_squared_error, sum_target, sum_squared_target = _explained_variance_update(preds, target)
self.n_obs = self.n_obs + n_obs
self.sum_error = self.sum_error + sum_error
self.sum_squared_error = self.sum_squared_error + sum_squared_error
self.sum_target = self.sum_target + sum_target
self.sum_squared_target = self.sum_squared_target + sum_squared_target
SkafteNicki marked this conversation as resolved.
Show resolved Hide resolved

def compute(self):
"""
Computes explained variance over state.
"""
preds = torch.cat(self.y_pred, dim=0)
target = torch.cat(self.y, dim=0)
return _explained_variance_compute(preds, target, self.multioutput)
return _explained_variance_compute(
self.n_obs,
self.sum_error,
self.sum_squared_error,
self.sum_target,
self.sum_squared_target,
self.multioutput,
)