Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add RelativeSquaredError #1765

Merged
merged 26 commits into from
May 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
b3ec758
Added functional and Module implementation of Relative Squared Error …
wbeardall May 8, 2023
22aded1
Updated docs
wbeardall May 8, 2023
55242a0
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 8, 2023
3581af6
Pre-commit check updates
wbeardall May 8, 2023
daddabf
Pre-commit check updates
wbeardall May 8, 2023
40431a7
Merge branch 'master' into master
Borda May 9, 2023
1f3aa69
Apply suggestions from code review
Borda May 9, 2023
d76ca38
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 9, 2023
8a6cf19
Merge branch 'master' into master
Borda May 9, 2023
190b88f
Apply suggestions from code review
Borda May 9, 2023
7c54e04
Merge branch 'master' into master
Borda May 9, 2023
dd50c25
Update src/torchmetrics/functional/regression/rse.py
justusschock May 9, 2023
7e21a87
changelog
SkafteNicki May 9, 2023
b215796
license
SkafteNicki May 9, 2023
80b2e96
Update src/torchmetrics/functional/regression/rse.py
SkafteNicki May 9, 2023
776e656
add plotting + docstring
SkafteNicki May 9, 2023
6439791
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 9, 2023
3ce09c5
fix typing
SkafteNicki May 9, 2023
4cefdf7
fix typing
SkafteNicki May 10, 2023
3304c41
Merge branch 'master' into master
SkafteNicki May 10, 2023
58e74b2
fix math
SkafteNicki May 10, 2023
025c27a
Merge branch 'master' into master
mergify[bot] May 10, 2023
ab6a9c8
fix broken test
SkafteNicki May 10, 2023
d44c169
Merge branch 'master' of https://github.com/wbeardall/torchmetrics in…
SkafteNicki May 10, 2023
3418385
try fix
SkafteNicki May 11, 2023
c50aa0f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 11, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
* `DistanceIntersectionOverUnion`


- Added `RelativeSquaredError` metric to regression package ([#1765](https://github.com/Lightning-AI/torchmetrics/pull/1765))


### Changed

- Changed `update_count` and `update_called` from private to public methods ([#1370](https://github.com/Lightning-AI/metrics/pull/1370))
Expand Down
23 changes: 23 additions & 0 deletions docs/source/regression/rse.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
.. customcarditem::
:header: Relative Squared Error
:image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/tabular_classification.svg
:tags: Regression

.. include:: ../links.rst

############################
Relative Squared Error (RSE)
############################

Module Interface
________________

.. autoclass:: torchmetrics.RelativeSquaredError
:noindex:
:exclude-members: update, compute

Functional Interface
____________________

.. autofunction:: torchmetrics.functional.relative_squared_error
:noindex:
2 changes: 2 additions & 0 deletions src/torchmetrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@
MinkowskiDistance,
PearsonCorrCoef,
R2Score,
RelativeSquaredError,
SpearmanCorrCoef,
SymmetricMeanAbsolutePercentageError,
TweedieDevianceScore,
Expand Down Expand Up @@ -190,6 +191,7 @@
"Recall",
"RecallAtFixedPrecision",
"RelativeAverageSpectralError",
"RelativeSquaredError",
"RetrievalFallOut",
"RetrievalHitRate",
"RetrievalMAP",
Expand Down
2 changes: 2 additions & 0 deletions src/torchmetrics/functional/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@
minkowski_distance,
pearson_corrcoef,
r2_score,
relative_squared_error,
spearman_corrcoef,
symmetric_mean_absolute_percentage_error,
tweedie_deviance_score,
Expand Down Expand Up @@ -190,6 +191,7 @@
"r2_score",
"recall",
"relative_average_spectral_error",
"relative_squared_error",
"retrieval_average_precision",
"retrieval_fall_out",
"retrieval_hit_rate",
Expand Down
2 changes: 2 additions & 0 deletions src/torchmetrics/functional/regression/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from torchmetrics.functional.regression.mse import mean_squared_error
from torchmetrics.functional.regression.pearson import pearson_corrcoef
from torchmetrics.functional.regression.r2 import r2_score
from torchmetrics.functional.regression.rse import relative_squared_error
from torchmetrics.functional.regression.spearman import spearman_corrcoef
from torchmetrics.functional.regression.symmetric_mape import symmetric_mean_absolute_percentage_error
from torchmetrics.functional.regression.tweedie_deviance import tweedie_deviance_score
Expand All @@ -45,6 +46,7 @@
"mean_absolute_percentage_error",
"minkowski_distance",
"r2_score",
"relative_squared_error",
"spearman_corrcoef",
"symmetric_mean_absolute_percentage_error",
"tweedie_deviance_score",
Expand Down
78 changes: 78 additions & 0 deletions src/torchmetrics/functional/regression/rse.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
# Copyright The Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Union

import torch
from torch import Tensor

from torchmetrics.functional.regression.r2 import _r2_score_update


def _relative_squared_error_compute(
sum_squared_obs: Tensor,
sum_obs: Tensor,
sum_squared_error: Tensor,
n_obs: Union[int, Tensor],
squared: bool = True,
) -> Tensor:
"""Computes Relative Squared Error.

Args:
sum_squared_obs: Sum of square of all observations
sum_obs: Sum of all observations
sum_squared_error: Residual sum of squares
n_obs: Number of predictions or observations
squared: Returns RRSE value if set to False.

Example:
>>> target = torch.tensor([[0.5, 1], [-1, 1], [7, -6]])
>>> preds = torch.tensor([[0, 2], [-1, 2], [8, -5]])
>>> # RSE uses the same update function as R2 score.
>>> sum_squared_obs, sum_obs, rss, n_obs = _r2_score_update(preds, target)
>>> _relative_squared_error_compute(sum_squared_obs, sum_obs, rss, n_obs, squared=True)
tensor(0.0632)
"""
epsilon = torch.finfo(sum_squared_error.dtype).eps
rse = sum_squared_error / torch.clamp(sum_squared_obs - sum_obs * sum_obs / n_obs, min=epsilon)
if not squared:
rse = torch.sqrt(rse)
return torch.mean(rse)


def relative_squared_error(preds: Tensor, target: Tensor, squared: bool = True) -> Tensor:
r"""Computes the relative squared error (RSE).

.. math:: \text{RSE} = \frac{\sum_i^N(y_i - \hat{y_i})^2}{\sum_i^N(y_i - \overline{y})^2}

Where :math:`y` is a tensor of target values with mean :math:`\overline{y}`, and
:math:`\hat{y}` is a tensor of predictions.

If `preds` and `targets` are 2D tensors, the RSE is averaged over the second dim.

Args:
preds: estimated labels
target: ground truth labels
squared: returns RRSE value if set to False
Return:
Tensor with RSE

Example:
>>> from torchmetrics.functional.regression import relative_squared_error
>>> target = torch.tensor([3, -0.5, 2, 7])
>>> preds = torch.tensor([2.5, 0.0, 2, 8])
>>> relative_squared_error(preds, target)
tensor(0.0514)
"""
sum_squared_obs, sum_obs, rss, n_obs = _r2_score_update(preds, target)
return _relative_squared_error_compute(sum_squared_obs, sum_obs, rss, n_obs, squared=squared)
2 changes: 2 additions & 0 deletions src/torchmetrics/regression/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from torchmetrics.regression.mse import MeanSquaredError
from torchmetrics.regression.pearson import PearsonCorrCoef
from torchmetrics.regression.r2 import R2Score
from torchmetrics.regression.rse import RelativeSquaredError
from torchmetrics.regression.spearman import SpearmanCorrCoef
from torchmetrics.regression.symmetric_mape import SymmetricMeanAbsolutePercentageError
from torchmetrics.regression.tweedie_deviance import TweedieDevianceScore
Expand All @@ -43,6 +44,7 @@
"MeanSquaredError",
"PearsonCorrCoef",
"R2Score",
"RelativeSquaredError",
"SpearmanCorrCoef",
"SymmetricMeanAbsolutePercentageError",
"TweedieDevianceScore",
Expand Down
141 changes: 141 additions & 0 deletions src/torchmetrics/regression/rse.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
# Copyright The Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Optional, Sequence, Union

import torch
from torch import Tensor, tensor

from torchmetrics.functional.regression.r2 import _r2_score_update
from torchmetrics.functional.regression.rse import _relative_squared_error_compute
from torchmetrics.metric import Metric
from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE
from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE

if not _MATPLOTLIB_AVAILABLE:
__doctest_skip__ = ["RelativeSquaredError.plot"]


class RelativeSquaredError(Metric):
r"""Computes the relative squared error (RSE).

.. math:: \text{RSE} = \frac{\sum_i^N(y_i - \hat{y_i})^2}{\sum_i^N(y_i - \overline{y})^2}

Where :math:`y` is a tensor of target values with mean :math:`\overline{y}`, and
:math:`\hat{y}` is a tensor of predictions.

If num_outputs > 1, the returned value is averaged over all the outputs.

As input to ``forward`` and ``update`` the metric accepts the following input:

- ``preds`` (:class:`~torch.Tensor`): Predictions from model in float tensor with shape ``(N,)``
or ``(N, M)`` (multioutput)
- ``target`` (:class:`~torch.Tensor`): Ground truth values in float tensor with shape ``(N,)``
or ``(N, M)`` (multioutput)

As output of ``forward`` and ``compute`` the metric returns the following output:

- ``rse`` (:class:`~torch.Tensor`): A tensor with the RSE score(s)

Args:
num_outputs: Number of outputs in multioutput setting
squared: If True returns RSE value, if False returns RRSE value.
kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.

Example:
>>> from torchmetrics.regression import RelativeSquaredError
>>> target = torch.tensor([3, -0.5, 2, 7])
>>> preds = torch.tensor([2.5, 0.0, 2, 8])
>>> relative_squared_error = RelativeSquaredError()
>>> relative_squared_error(preds, target)
tensor(0.0514)
"""
is_differentiable = True
higher_is_better = False
full_state_update = False
sum_squared_error: Tensor
sum_error: Tensor
residual: Tensor
total: Tensor

def __init__(
self,
num_outputs: int = 1,
squared: bool = True,
**kwargs: Any,
) -> None:
super().__init__(**kwargs)

self.num_outputs = num_outputs

self.add_state("sum_squared_error", default=torch.zeros(self.num_outputs), dist_reduce_fx="sum")
self.add_state("sum_error", default=torch.zeros(self.num_outputs), dist_reduce_fx="sum")
self.add_state("residual", default=torch.zeros(self.num_outputs), dist_reduce_fx="sum")
self.add_state("total", default=tensor(0), dist_reduce_fx="sum")
self.squared = squared

def update(self, preds: Tensor, target: Tensor) -> None:
"""Update state with predictions and targets."""
sum_squared_error, sum_error, residual, total = _r2_score_update(preds, target)

self.sum_squared_error += sum_squared_error
self.sum_error += sum_error
self.residual += residual
self.total += total

def compute(self) -> Tensor:
"""Computes relative squared error over state."""
return _relative_squared_error_compute(
self.sum_squared_error, self.sum_error, self.residual, self.total, squared=self.squared
)

def plot(
self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
) -> _PLOT_OUT_TYPE:
"""Plot a single or multiple values from the metric.

Args:
val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
If no value is provided, will automatically call `metric.compute` and plot that result.
ax: An matplotlib axis object. If provided will add plot to that axis

Returns:
Figure and Axes object

Raises:
ModuleNotFoundError:
If `matplotlib` is not installed

.. plot::
:scale: 75

>>> from torch import randn
>>> # Example plotting a single value
>>> from torchmetrics.regression import RelativeSquaredError
>>> metric = RelativeSquaredError()
>>> metric.update(randn(10,), randn(10,))
>>> fig_, ax_ = metric.plot()

.. plot::
:scale: 75

>>> from torch import randn
>>> # Example plotting multiple values
>>> from torchmetrics.regression import RelativeSquaredError
>>> metric = RelativeSquaredError()
>>> values = []
>>> for _ in range(10):
... values.append(metric(randn(10,), randn(10,)))
>>> fig, ax = metric.plot(values)
"""
return self._plot(val, ax)
Loading