-
Notifications
You must be signed in to change notification settings - Fork 411
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Added Cosine Similarity metric (#305)
* Added Cosine Similarity metric * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Merge branch 'master' into master * Apply suggestions from code review * Apply suggestions from code review Co-authored-by: Nicki Skafte <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * update deepsource * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Made Cosine Similarity a Regression Metric * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Some comments correction * Some aforementioned corrections and addition of tests * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update cosine_similarity.py * Some comments correction * fix doctest * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * remove none test case * doc string * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Jirka Borovec <[email protected]> Co-authored-by: Nicki Skafte <[email protected]> Co-authored-by: Jirka <[email protected]> Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
- Loading branch information
1 parent
cac72af
commit 578c8f5
Showing
10 changed files
with
304 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,112 @@ | ||
# Copyright The PyTorch Lightning team. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
from collections import namedtuple | ||
from functools import partial | ||
|
||
import numpy as np | ||
import pytest | ||
import torch | ||
from sklearn.metrics.pairwise import cosine_similarity as sk_cosine | ||
|
||
from tests.helpers import seed_all | ||
from tests.helpers.testers import BATCH_SIZE, NUM_BATCHES, MetricTester | ||
from torchmetrics.functional.regression.cosine_similarity import cosine_similarity | ||
from torchmetrics.regression.cosine_similarity import CosineSimilarity | ||
|
||
seed_all(42) | ||
|
||
num_targets = 5 | ||
|
||
Input = namedtuple('Input', ["preds", "target"]) | ||
|
||
_single_target_inputs = Input( | ||
preds=torch.rand(NUM_BATCHES, BATCH_SIZE), | ||
target=torch.rand(NUM_BATCHES, BATCH_SIZE), | ||
) | ||
|
||
_multi_target_inputs = Input( | ||
preds=torch.rand(NUM_BATCHES, BATCH_SIZE, num_targets), | ||
target=torch.rand(NUM_BATCHES, BATCH_SIZE, num_targets), | ||
) | ||
|
||
|
||
def _multi_target_sk_metric(preds, target, reduction, sk_fn=sk_cosine): | ||
sk_preds = preds.view(-1, num_targets).numpy() | ||
sk_target = target.view(-1, num_targets).numpy() | ||
result_array = sk_fn(sk_target, sk_preds) | ||
col = np.diagonal(result_array) | ||
sum = col.sum() | ||
if reduction == 'sum': | ||
to_return = sum | ||
elif reduction == 'mean': | ||
mean = sum / len(col) | ||
to_return = mean | ||
else: | ||
to_return = col | ||
return to_return | ||
|
||
|
||
def _single_target_sk_metric(preds, target, reduction, sk_fn=sk_cosine): | ||
sk_preds = preds.view(-1).numpy() | ||
sk_target = target.view(-1).numpy() | ||
result_array = sk_fn(np.expand_dims(sk_preds, axis=0), np.expand_dims(sk_target, axis=0)) | ||
col = np.diagonal(result_array) | ||
sum = col.sum() | ||
if reduction == 'sum': | ||
to_return = sum | ||
elif reduction == 'mean': | ||
mean = sum / len(col) | ||
to_return = mean | ||
else: | ||
to_return = col | ||
return to_return | ||
|
||
|
||
@pytest.mark.parametrize("reduction", ['sum', 'mean']) | ||
@pytest.mark.parametrize( | ||
"preds, target, sk_metric", | ||
[ | ||
(_single_target_inputs.preds, _single_target_inputs.target, _single_target_sk_metric), | ||
(_multi_target_inputs.preds, _multi_target_inputs.target, _multi_target_sk_metric), | ||
], | ||
) | ||
class TestCosineSimilarity(MetricTester): | ||
|
||
@pytest.mark.parametrize("ddp", [True, False]) | ||
@pytest.mark.parametrize("dist_sync_on_step", [True, False]) | ||
def test_cosine_similarity(self, reduction, preds, target, sk_metric, ddp, dist_sync_on_step): | ||
self.run_class_metric_test( | ||
ddp, | ||
preds, | ||
target, | ||
CosineSimilarity, | ||
partial(sk_metric, reduction=reduction), | ||
dist_sync_on_step, | ||
metric_args=dict(reduction=reduction), | ||
) | ||
|
||
def test_cosine_similarity_functional(self, reduction, preds, target, sk_metric): | ||
self.run_functional_metric_test( | ||
preds, | ||
target, | ||
cosine_similarity, | ||
partial(sk_metric, reduction=reduction), | ||
metric_args=dict(reduction=reduction), | ||
) | ||
|
||
|
||
def test_error_on_different_shape(metric_class=CosineSimilarity): | ||
metric = metric_class() | ||
with pytest.raises(RuntimeError, match='Predictions and targets are expected to have the same shape'): | ||
metric(torch.randn(100, ), torch.randn(50, )) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
# Copyright The PyTorch Lightning team. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
from typing import Tuple | ||
|
||
import torch | ||
from torch import Tensor | ||
|
||
from torchmetrics.utilities.checks import _check_same_shape | ||
|
||
|
||
def _cosine_similarity_update( | ||
preds: Tensor, | ||
target: Tensor, | ||
) -> Tuple[Tensor, Tensor]: | ||
_check_same_shape(preds, target) | ||
preds = preds.float() | ||
target = target.float() | ||
|
||
return preds, target | ||
|
||
|
||
def _cosine_similarity_compute(preds: Tensor, target: Tensor, reduction='sum') -> Tensor: | ||
dot_product = (preds * target).sum(dim=-1) | ||
preds_norm = preds.norm(dim=-1) | ||
target_norm = target.norm(dim=-1) | ||
similarity = dot_product / (preds_norm * target_norm) | ||
reduction_mapping = {"sum": torch.sum, "mean": torch.mean, "none": lambda x: x} | ||
return reduction_mapping[reduction](similarity) | ||
|
||
|
||
def cosine_similarity(preds: Tensor, target: Tensor, reduction='sum') -> Tensor: | ||
r""" | ||
Computes the `Cosine Similarity <https://en.wikipedia.org/wiki/Cosine_similarity>`_ | ||
between targets and predictions: | ||
.. math:: | ||
cos_{sim}(x,y) = \frac{x \cdot y}{||x|| \cdot ||y||} = | ||
\frac{\sum_{i=1}^n x_i y_i}{\sqrt{\sum_{i=1}^n x_i^2}\sqrt{\sum_{i=1}^n y_i^2}} | ||
where :math:`y` is a tensor of target values, and :math:`x` is a tensor of predictions. | ||
Args: | ||
preds: Predicted tensor with shape ``(N,d)`` | ||
target: Ground truth tensor with shape ``(N,d)`` | ||
reduction: | ||
The method of reducing along the batch dimension using sum, mean or taking the individual scores | ||
Example: | ||
>>> from torchmetrics.functional.regression import cosine_similarity | ||
>>> target = torch.tensor([[1, 2, 3, 4], | ||
... [1, 2, 3, 4]]) | ||
>>> preds = torch.tensor([[1, 2, 3, 4], | ||
... [-1, -2, -3, -4]]) | ||
>>> cosine_similarity(preds, target, 'none') | ||
tensor([ 1.0000, -1.0000]) | ||
""" | ||
preds, target = _cosine_similarity_update(preds, target) | ||
return _cosine_similarity_compute(preds, target, reduction) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,104 @@ | ||
# Copyright The PyTorch Lightning team. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
from typing import Any, Callable, Optional | ||
|
||
import torch | ||
from torch import Tensor | ||
|
||
from torchmetrics.functional.regression.cosine_similarity import _cosine_similarity_compute, _cosine_similarity_update | ||
from torchmetrics.metric import Metric | ||
from torchmetrics.utilities.data import dim_zero_cat | ||
|
||
|
||
class CosineSimilarity(Metric): | ||
r""" | ||
Computes the `Cosine Similarity <https://en.wikipedia.org/wiki/Cosine_similarity>`_ | ||
between targets and predictions: | ||
.. math:: | ||
cos_{sim}(x,y) = \frac{x \cdot y}{||x|| \cdot ||y||} = | ||
\frac{\sum_{i=1}^n x_i y_i}{\sqrt{\sum_{i=1}^n x_i^2}\sqrt{\sum_{i=1}^n y_i^2}} | ||
where :math:`y` is a tensor of target values, and :math:`x` is a tensor of predictions. | ||
Forward accepts | ||
- ``preds`` (float tensor): ``(N,d)`` | ||
- ``target`` (float tensor): ``(N,d)`` | ||
Args: | ||
reduction : how to reduce over the batch dimension using 'sum', 'mean' or 'none' | ||
(taking the individual scores) | ||
compute_on_step: | ||
Forward only calls ``update()`` and return ``None`` if this is set to ``False``. | ||
dist_sync_on_step: | ||
Synchronize metric state across processes at each ``forward()`` | ||
before returning the value at the step. | ||
process_group: | ||
Specify the process group on which synchronization is called. | ||
default: ``None`` (which selects the entire world) | ||
dist_sync_fn: | ||
Callback that performs the allgather operation on the metric state. When ``None``, DDP | ||
will be used to perform the all gather. | ||
Example: | ||
>>> from torchmetrics import CosineSimilarity | ||
>>> target = torch.tensor([[0, 1], [1, 1]]) | ||
>>> preds = torch.tensor([[0, 1], [0, 1]]) | ||
>>> cosine_similarity = CosineSimilarity(reduction = 'mean') | ||
>>> cosine_similarity(preds, target) | ||
tensor(0.8536) | ||
""" | ||
|
||
def __init__( | ||
self, | ||
reduction: str = 'sum', | ||
compute_on_step: bool = True, | ||
dist_sync_on_step: bool = False, | ||
process_group: Optional[Any] = None, | ||
dist_sync_fn: Callable = None, | ||
): | ||
super().__init__( | ||
compute_on_step=compute_on_step, | ||
dist_sync_on_step=dist_sync_on_step, | ||
process_group=process_group, | ||
dist_sync_fn=dist_sync_fn | ||
) | ||
|
||
self.add_state("preds", [], dist_reduce_fx="cat") | ||
self.add_state("target", [], dist_reduce_fx="cat") | ||
self.reduction = reduction | ||
|
||
def update(self, preds: Tensor, target: Tensor): | ||
""" | ||
Update metric states with predictions and targets. | ||
Args: | ||
preds: Predicted tensor with shape ``(N,d)`` | ||
target: Ground truth tensor with shape ``(N,d)`` | ||
""" | ||
preds, target = _cosine_similarity_update(preds, target) | ||
|
||
self.preds.append(preds) | ||
self.target.append(target) | ||
|
||
def compute(self): | ||
preds = dim_zero_cat(self.preds) | ||
target = dim_zero_cat(self.target) | ||
return _cosine_similarity_compute(preds, target, self.reduction) | ||
|
||
@property | ||
def is_differentiable(self) -> bool: | ||
return True |