Skip to content

Commit

Permalink
Added MinkowskiDistance support (#1362)
Browse files Browse the repository at this point in the history
Co-authored-by: Nicki Skafte Detlefsen <[email protected]>
Co-authored-by: Jirka Borovec <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
Co-authored-by: Jirka <[email protected]>
  • Loading branch information
6 people authored Feb 24, 2023
1 parent 606dc17 commit b5aaa60
Show file tree
Hide file tree
Showing 14 changed files with 428 additions and 8 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added `classes` to output from `MAP` metric ([#1419](https://github.com/Lightning-AI/metrics/pull/1419))


- Added `MinkowskiDistance` to regression package ([#1362](https://github.com/Lightning-AI/metrics/pull/1362))


- Added `pairwise_minkowski_distance` to pairwise package ([#1362](https://github.com/Lightning-AI/metrics/pull/1362))


- Added new detection metric `PanopticQuality` ([#929](https://github.com/PyTorchLightning/metrics/pull/929))


Expand Down
1 change: 1 addition & 0 deletions docs/source/links.rst
Original file line number Diff line number Diff line change
Expand Up @@ -130,3 +130,4 @@
.. _Multilabel coverage error: https://link.springer.com/chapter/10.1007/978-0-387-09823-4_34
.. _Panoptic Quality: https://arxiv.org/abs/1801.00868
.. _torchmetrics mAP example: https://github.com/Lightning-AI/metrics/blob/master/examples/detection_map.py
.. _Minkowski Distance: https://en.wikipedia.org/wiki/Minkowski_distance
14 changes: 14 additions & 0 deletions docs/source/pairwise/minkowski_distance.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
.. customcarditem::
:header: Pairwise Minkowski Distance
:image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/translation.svg
:tags: Pairwise

##################
Minkowski Distance
##################

Functional Interface
____________________

.. autofunction:: torchmetrics.functional.pairwise_minkowski_distance
:noindex:
23 changes: 23 additions & 0 deletions docs/source/regression/minkowski_distance.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
.. customcarditem::
:header: Minkowski Distance
:image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/tabular_classification.svg
:tags: Regression

.. include:: ../links.rst

##################
Minkowski Distance
##################

Module Interface
________________

.. autoclass:: torchmetrics.MinkowskiDistance
:noindex:
:exclude-members: update, compute

Functional Interface
____________________

.. autofunction:: torchmetrics.functional.minkowski_distance
:noindex:
1 change: 1 addition & 0 deletions src/torchmetrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@
MeanAbsolutePercentageError,
MeanSquaredError,
MeanSquaredLogError,
MinkowskiDistance,
PearsonCorrCoef,
R2Score,
SpearmanCorrCoef,
Expand Down
4 changes: 4 additions & 0 deletions src/torchmetrics/functional/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
from torchmetrics.functional.pairwise.euclidean import pairwise_euclidean_distance
from torchmetrics.functional.pairwise.linear import pairwise_linear_similarity
from torchmetrics.functional.pairwise.manhattan import pairwise_manhattan_distance
from torchmetrics.functional.pairwise.minkowski import pairwise_minkowski_distance
from torchmetrics.functional.regression.concordance import concordance_corrcoef
from torchmetrics.functional.regression.cosine_similarity import cosine_similarity
from torchmetrics.functional.regression.explained_variance import explained_variance
Expand All @@ -64,6 +65,7 @@
from torchmetrics.functional.regression.log_mse import mean_squared_log_error
from torchmetrics.functional.regression.mae import mean_absolute_error
from torchmetrics.functional.regression.mape import mean_absolute_percentage_error
from torchmetrics.functional.regression.minkowski import minkowski_distance
from torchmetrics.functional.regression.mse import mean_squared_error
from torchmetrics.functional.regression.pearson import pearson_corrcoef
from torchmetrics.functional.regression.r2 import r2_score
Expand Down Expand Up @@ -134,11 +136,13 @@
"mean_absolute_percentage_error",
"mean_squared_error",
"mean_squared_log_error",
"minkowski_distance",
"multiscale_structural_similarity_index_measure",
"pairwise_cosine_similarity",
"pairwise_euclidean_distance",
"pairwise_linear_similarity",
"pairwise_manhattan_distance",
"pairwise_minkowski_distance",
"panoptic_quality",
"pearson_corrcoef",
"pearsons_contingency_coefficient",
Expand Down
1 change: 1 addition & 0 deletions src/torchmetrics/functional/pairwise/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,4 @@
from torchmetrics.functional.pairwise.euclidean import pairwise_euclidean_distance # noqa: F401
from torchmetrics.functional.pairwise.linear import pairwise_linear_similarity # noqa: F401
from torchmetrics.functional.pairwise.manhattan import pairwise_manhattan_distance # noqa: F401
from torchmetrics.functional.pairwise.minkowski import pairwise_minkowski_distance # noqa: F401
91 changes: 91 additions & 0 deletions src/torchmetrics/functional/pairwise/minkowski.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Union

import torch
from torch import Tensor
from typing_extensions import Literal

from torchmetrics.functional.pairwise.helpers import _check_input, _reduce_distance_matrix
from torchmetrics.utilities.exceptions import TorchMetricsUserError


def _pairwise_minkowski_distance_update(
x: Tensor, y: Optional[Tensor] = None, exponent: Union[int, float] = 2, zero_diagonal: Optional[bool] = None
) -> Tensor:
"""Calculate the pairwise minkowski distance matrix.
Args:
x: tensor of shape ``[N,d]``
y: tensor of shape ``[M,d]``
exponent: int or float larger than 1, exponent to which the difference between preds and target is to be raised
zero_diagonal: determines if the diagonal of the distance matrix should be set to zero
"""
x, y, zero_diagonal = _check_input(x, y, zero_diagonal)
if not (isinstance(exponent, (float, int)) and exponent >= 1):
raise TorchMetricsUserError(f"Argument ``p`` must be a float or int greater than 1, but got {exponent}")
# upcast to float64 to prevent precision issues
_orig_dtype = x.dtype
x = x.to(torch.float64)
y = y.to(torch.float64)
distance = (x.unsqueeze(1) - y.unsqueeze(0)).abs().pow(exponent).sum(-1).pow(1.0 / exponent)
if zero_diagonal:
distance.fill_diagonal_(0)
return distance.to(_orig_dtype)


def pairwise_minkowski_distance(
x: Tensor,
y: Optional[Tensor] = None,
exponent: Union[int, float] = 2,
reduction: Literal["mean", "sum", "none", None] = None,
zero_diagonal: Optional[bool] = None,
) -> Tensor:
r"""Calculate pairwise minkowski distances.
.. math::
d_{minkowski}(x,y,p) = ||x - y||_p = \sqrt[p]{\sum_{d=1}^D (x_d - y_d)^p}
If both :math:`x` and :math:`y` are passed in, the calculation will be performed pairwise between the rows of
:math:`x` and :math:`y`. If only :math:`x` is passed in, the calculation will be performed between the rows
of :math:`x`.
Args:
x: Tensor with shape ``[N, d]``
y: Tensor with shape ``[M, d]``, optional
exponent: int or float larger than 1, exponent to which the difference between preds and target is to be raised
reduction: reduction to apply along the last dimension. Choose between `'mean'`, `'sum'`
(applied along column dimension) or `'none'`, `None` for no reduction
zero_diagonal: if the diagonal of the distance matrix should be set to 0. If only `x` is given
this defaults to `True` else if `y` is also given it defaults to `False`
Returns:
A ``[N,N]`` matrix of distances if only ``x`` is given, else a ``[N,M]`` matrix
Example:
>>> import torch
>>> from torchmetrics.functional import pairwise_minkowski_distance
>>> x = torch.tensor([[2, 3], [3, 5], [5, 8]], dtype=torch.float32)
>>> y = torch.tensor([[1, 0], [2, 1]], dtype=torch.float32)
>>> pairwise_minkowski_distance(x, y, exponent=4)
tensor([[3.0092, 2.0000],
[5.0317, 4.0039],
[8.1222, 7.0583]])
>>> pairwise_minkowski_distance(x, exponent=4)
tensor([[0.0000, 2.0305, 5.1547],
[2.0305, 0.0000, 3.1383],
[5.1547, 3.1383, 0.0000]])
"""
distance = _pairwise_minkowski_distance_update(x, y, exponent, zero_diagonal)
return _reduce_distance_matrix(distance, reduction)
1 change: 1 addition & 0 deletions src/torchmetrics/functional/regression/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from torchmetrics.functional.regression.log_mse import mean_squared_log_error # noqa: F401
from torchmetrics.functional.regression.mae import mean_absolute_error # noqa: F401
from torchmetrics.functional.regression.mape import mean_absolute_percentage_error # noqa: F401
from torchmetrics.functional.regression.minkowski import minkowski_distance # noqa: F401
from torchmetrics.functional.regression.mse import mean_squared_error # noqa: F401
from torchmetrics.functional.regression.pearson import pearson_corrcoef # noqa: F401
from torchmetrics.functional.regression.r2 import r2_score # noqa: F401
Expand Down
83 changes: 83 additions & 0 deletions src/torchmetrics/functional/regression/minkowski.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from torch import Tensor

from torchmetrics.utilities.checks import _check_same_shape
from torchmetrics.utilities.exceptions import TorchMetricsUserError


def _minkowski_distance_update(preds: Tensor, targets: Tensor, p: float) -> Tensor:
"""Update and return variables required to compute Minkowski distance.
Checks for same shape of input tensors.
Args:
preds: Predicted tensor
targets: Ground truth tensor
p: Non-negative number acting as the p to the errors
"""
_check_same_shape(preds, targets)

if not (isinstance(p, (float, int)) and p >= 1):
raise TorchMetricsUserError(f"Argument ``p`` must be a float or int greater than 1, but got {p}")

difference = torch.abs(preds - targets)
mink_dist_sum = torch.sum(torch.pow(difference, p))

return mink_dist_sum


def _minkowski_distance_compute(distance: Tensor, p: float) -> Tensor:
"""Compute Minkowski Distance.
Args:
distance: Sum of the p-th powers of errors over all observations
p: The non-negative numeric power the errors are to be raised to
Example:
>>> preds = torch.tensor([0., 1, 2, 3])
>>> target = torch.tensor([0., 2, 3, 1])
>>> distance_p_sum = _minkowski_distance_update(preds, target, 5)
>>> _minkowski_distance_compute(distance_p_sum, 5)
tensor(2.0244)
"""
return torch.pow(distance, 1.0 / p)


def minkowski_distance(preds: Tensor, targets: Tensor, p: float) -> Tensor:
r"""Compute the `Minkowski distance`_.
.. math:: d_{\text{Minkowski}} = \\sum_{i}^N (| y_i - \\hat{y_i} |^p)^\frac{1}{p}
This metric can be seen as generalized version of the standard euclidean distance which corresponds to minkowski
distance with p=2.
Args:
preds: estimated labels of type Tensor
targets: ground truth labels of type Tensor
p: int or float larger than 1, exponent to which the difference between preds and target is to be raised
Return:
Tensor with the Minkowski distance
Example:
>>> from torchmetrics.functional import minkowski_distance
>>> x = torch.tensor([1.0, 2.8, 3.5, 4.5])
>>> y = torch.tensor([6.1, 2.11, 3.1, 5.6])
>>> minkowski_distance(x, y, p=3)
tensor(5.1220)
"""
minkowski_dist_sum = _minkowski_distance_update(preds, targets, p)
return _minkowski_distance_compute(minkowski_dist_sum, p)
1 change: 1 addition & 0 deletions src/torchmetrics/regression/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from torchmetrics.regression.log_mse import MeanSquaredLogError # noqa: F401
from torchmetrics.regression.mae import MeanAbsoluteError # noqa: F401
from torchmetrics.regression.mape import MeanAbsolutePercentageError # noqa: F401
from torchmetrics.regression.minkowski import MinkowskiDistance # noqa: F401
from torchmetrics.regression.mse import MeanSquaredError # noqa: F401
from torchmetrics.regression.pearson import PearsonCorrCoef # noqa: F401
from torchmetrics.regression.r2 import R2Score # noqa: F401
Expand Down
70 changes: 70 additions & 0 deletions src/torchmetrics/regression/minkowski.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Optional

from torch import Tensor, tensor

from torchmetrics.functional.regression.minkowski import _minkowski_distance_compute, _minkowski_distance_update
from torchmetrics.metric import Metric
from torchmetrics.utilities.exceptions import TorchMetricsUserError


class MinkowskiDistance(Metric):
r"""Compute `Minkowski Distance`_.
.. math:: d_{\text{Minkowski}} = \\sum_{i}^N (| y_i - \\hat{y_i} |^p)^\frac{1}{p}
This metric can be seen as generalized version of the standard euclidean distance which corresponds to minkowski
distance with p=2.
where
:math:`y` is a tensor of target values,
:math:`\\hat{y}` is a tensor of predictions,
:math: `\\p` is a non-negative integer or floating-point number
Args:
p: int or float larger than 1, exponent to which the difference between preds and target is to be raised
kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
Example:
>>> from torchmetrics import MinkowskiDistance
>>> target = tensor([1.0, 2.8, 3.5, 4.5])
>>> preds = tensor([6.1, 2.11, 3.1, 5.6])
>>> minkowski_distance = MinkowskiDistance(3)
>>> minkowski_distance(preds, target)
tensor(5.1220)
"""

is_differentiable: Optional[bool] = True
higher_is_better: Optional[bool] = False
full_state_update: Optional[bool] = False
minkowski_dist_sum: Tensor

def __init__(self, p: float, **kwargs: Any) -> None:
super().__init__(**kwargs)
if not (isinstance(p, (float, int)) and p >= 1):
raise TorchMetricsUserError(f"Argument ``p`` must be a float or int greater than 1, but got {p}")

self.p = p
self.add_state("minkowski_dist_sum", default=tensor(0.0), dist_reduce_fx="sum")

def update(self, preds: Tensor, targets: Tensor) -> None:
"""Update state with predictions and targets."""
minkowski_dist_sum = _minkowski_distance_update(preds, targets, self.p)
self.minkowski_dist_sum += minkowski_dist_sum

def compute(self) -> Tensor:
"""Compute metric."""
return _minkowski_distance_compute(self.minkowski_dist_sum, self.p)
Loading

0 comments on commit b5aaa60

Please sign in to comment.