Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Metrics] Unification of regression #4166

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 37 additions & 23 deletions docs/source/metrics.rst
Original file line number Diff line number Diff line change
Expand Up @@ -31,14 +31,14 @@ The example below shows how to use a metric in your ``LightningModule``:
def __init__(self):
...
self.accuracy = pl.metrics.Accuracy()

def training_step(self, batch, batch_idx):
logits = self(x)
...
# log step metric
self.log('train_acc_step', self.accuracy(logits, y))
...

def training_epoch_end(self, outs):
# log epoch metric
self.log('train_acc_epoch', self.accuracy.compute())
Expand All @@ -57,15 +57,15 @@ If ``on_epoch`` is True, the logger automatically logs the end of epoch metric v
This however is only true for metrics that inherit the base class ``Metric``,
and thus the functional metric API provides no support for in-built distributed synchronization
or reduction functions.


.. code-block:: python

def __init__(self):
...
self.train_acc = pl.metrics.Accuracy()
self.valid_acc = pl.metrics.Accuracy()

def training_step(self, batch, batch_idx):
logits = self(x)
...
Expand All @@ -91,17 +91,17 @@ This metrics API is independent of PyTorch Lightning. Metrics can directly be us
for epoch in range(epochs):
for x, y in train_data:
y_hat = model(x)

# training step accuracy
batch_acc = train_accuracy(y_hat, y)

for x, y in valid_data:
y_hat = model(x)
valid_accuracy(y_hat, y)

# total accuracy over all training batches
total_train_accuracy = train_accuracy.compute()

# total accuracy over all validation batches
total_valid_accuracy = valid_accuracy.compute()

Expand Down Expand Up @@ -212,6 +212,20 @@ ExplainedVariance
.. autoclass:: pytorch_lightning.metrics.regression.ExplainedVariance
:noindex:


PSNR
~~~~

.. autoclass:: pytorch_lightning.metrics.regression.PSNR
:noindex:


SSIM
~~~~

.. autoclass:: pytorch_lightning.metrics.regression.SSIM
:noindex:

******************
Functional Metrics
******************
Expand Down Expand Up @@ -360,45 +374,45 @@ to_onehot [func]
Regression
----------

mae [func]
~~~~~~~~~~
explained_variance [func]
~~~~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: pytorch_lightning.metrics.functional.regression.mae
.. autofunction:: pytorch_lightning.metrics.functional.explained_variance
:noindex:


mse [func]
~~~~~~~~~~
mean_absolute_error [func]
~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: pytorch_lightning.metrics.functional.regression.mse
.. autofunction:: pytorch_lightning.metrics.functional.mean_absolute_error
:noindex:


psnr [func]
~~~~~~~~~~~
mean_squared_error [func]
~~~~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: pytorch_lightning.metrics.functional.regression.psnr
.. autofunction:: pytorch_lightning.metrics.functional.mean_squared_error
:noindex:


rmse [func]
psnr [func]
~~~~~~~~~~~

.. autofunction:: pytorch_lightning.metrics.functional.regression.rmse
.. autofunction:: pytorch_lightning.metrics.functional.psnr
:noindex:


rmsle [func]
~~~~~~~~~~~~
mean_squared_log_error [func]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: pytorch_lightning.metrics.functional.regression.rmsle
.. autofunction:: pytorch_lightning.metrics.functional.mean_squared_log_error
:noindex:


ssim [func]
~~~~~~~~~~~

.. autofunction:: pytorch_lightning.metrics.functional.regression.mae
.. autofunction:: pytorch_lightning.metrics.functional.ssim
:noindex:


Expand Down
2 changes: 2 additions & 0 deletions pytorch_lightning/metrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,4 +25,6 @@
MeanAbsoluteError,
MeanSquaredLogError,
ExplainedVariance,
PSNR,
SSIM,
)
15 changes: 7 additions & 8 deletions pytorch_lightning/metrics/functional/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,14 +34,13 @@
iou,
)
from pytorch_lightning.metrics.functional.nlp import bleu_score
from pytorch_lightning.metrics.functional.regression import (
mae,
mse,
psnr,
rmse,
rmsle,
ssim
)
from pytorch_lightning.metrics.functional.self_supervised import (
embedding_similarity
)
# TODO: unify metrics between class and functional, add below
from pytorch_lightning.metrics.functional.explained_variance import explained_variance
from pytorch_lightning.metrics.functional.mean_absolute_error import mean_absolute_error
from pytorch_lightning.metrics.functional.mean_squared_error import mean_squared_error
from pytorch_lightning.metrics.functional.mean_squared_log_error import mean_squared_log_error
from pytorch_lightning.metrics.functional.psnr import psnr
from pytorch_lightning.metrics.functional.ssim import ssim
85 changes: 85 additions & 0 deletions pytorch_lightning/metrics/functional/explained_variance.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Union, Tuple, Sequence

import torch
Borda marked this conversation as resolved.
Show resolved Hide resolved
from pytorch_lightning.metrics.utils import _check_same_shape


def _explained_variance_update(preds: torch.Tensor, target: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
_check_same_shape(preds, target)
return preds, target


def _explained_variance_compute(preds: torch.Tensor,
target: torch.Tensor,
multioutput: str = 'uniform_average',
) -> Union[torch.Tensor, Sequence[torch.Tensor]]:
diff_avg = torch.mean(target - preds, dim=0)
numerator = torch.mean((target - preds - diff_avg) ** 2, dim=0)

target_avg = torch.mean(target, dim=0)
denominator = torch.mean((target - target_avg) ** 2, dim=0)

# Take care of division by zero
nonzero_numerator = numerator != 0
nonzero_denominator = denominator != 0
valid_score = nonzero_numerator & nonzero_denominator
output_scores = torch.ones_like(diff_avg)
output_scores[valid_score] = 1.0 - (numerator[valid_score] / denominator[valid_score])
output_scores[nonzero_numerator & ~nonzero_denominator] = 0.

# Decide what to do in multioutput case
# Todo: allow user to pass in tensor with weights
if multioutput == 'raw_values':
return output_scores
if multioutput == 'uniform_average':
return torch.mean(output_scores)
if multioutput == 'variance_weighted':
denom_sum = torch.sum(denominator)
return torch.sum(denominator / denom_sum * output_scores)


def explained_variance(preds: torch.Tensor,
target: torch.Tensor,
multioutput: str = 'uniform_average',
) -> Union[torch.Tensor, Sequence[torch.Tensor]]:
"""
Computes explained variance.

Args:
pred: estimated labels
target: ground truth labels
multioutput: Defines aggregation in the case of multiple output scores. Can be one
of the following strings (default is `'uniform_average'`.):

* `'raw_values'` returns full set of scores
* `'uniform_average'` scores are uniformly averaged
* `'variance_weighted'` scores are weighted by their individual variances

Example:

>>> from pytorch_lightning.metrics.functional import explained_variance
>>> target = torch.tensor([3, -0.5, 2, 7])
>>> preds = torch.tensor([2.5, 0.0, 2, 8])
>>> explained_variance(preds, target)
tensor(0.9572)

>>> target = torch.tensor([[0.5, 1], [-1, 1], [7, -6]])
>>> preds = torch.tensor([[0, 2], [-1, 2], [8, -5]])
>>> explained_variance(preds, target, multioutput='raw_values')
tensor([0.9677, 1.0000])
"""
preds, target = _explained_variance_update(preds, target)
return _explained_variance_compute(preds, target, multioutput)
51 changes: 51 additions & 0 deletions pytorch_lightning/metrics/functional/mean_absolute_error.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Tuple

import torch
Borda marked this conversation as resolved.
Show resolved Hide resolved
from pytorch_lightning.metrics.utils import _check_same_shape


def _mean_absolute_error_update(preds: torch.Tensor, target: torch.Tensor) -> Tuple[torch.Tensor, int]:
_check_same_shape(preds, target)
sum_abs_error = torch.sum(torch.abs(preds - target))
n_obs = target.numel()
return sum_abs_error, n_obs


def _mean_absolute_error_compute(sum_abs_error: torch.Tensor, n_obs: int) -> torch.Tensor:
return sum_abs_error / n_obs


def mean_absolute_error(preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
"""
Computes mean absolute error

Args:
pred: estimated labels
target: ground truth labels

Return:
Tensor with MAE

Example:

>>> x = torch.tensor([0., 1, 2, 3])
>>> y = torch.tensor([0., 1, 2, 2])
>>> mean_absolute_error(x, y)
tensor(0.2500)

"""
sum_abs_error, n_obs = _mean_absolute_error_update(preds, target)
return _mean_absolute_error_compute(sum_abs_error, n_obs)
51 changes: 51 additions & 0 deletions pytorch_lightning/metrics/functional/mean_squared_error.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Tuple

import torch
Borda marked this conversation as resolved.
Show resolved Hide resolved
from pytorch_lightning.metrics.utils import _check_same_shape


def _mean_squared_error_update(preds: torch.Tensor, target: torch.Tensor) -> Tuple[torch.Tensor, int]:
_check_same_shape(preds, target)
sum_squared_error = torch.sum(torch.pow(preds - target, 2))
n_obs = target.numel()
return sum_squared_error, n_obs


def _mean_squared_error_compute(sum_squared_error: torch.Tensor, n_obs: int) -> torch.Tensor:
return sum_squared_error / n_obs


def mean_squared_error(preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
"""
Computes mean squared error

Args:
pred: estimated labels
target: ground truth labels

Return:
Tensor with MSE

Example:

>>> x = torch.tensor([0., 1, 2, 3])
>>> y = torch.tensor([0., 1, 2, 2])
>>> mean_squared_error(x, y)
tensor(0.2500)

"""
sum_squared_error, n_obs = _mean_squared_error_update(preds, target)
return _mean_squared_error_compute(sum_squared_error, n_obs)
Loading