diff --git a/CHANGELOG.md b/CHANGELOG.md
index 94a52a9a6fa..b23ce58d355 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -11,7 +11,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

 ### Added

--
+- Added `MutualInfoScore` metric to the clustering package ([#2008](https://github.com/Lightning-AI/torchmetrics/pull/2008))

 ### Changed

@@ -26,7 +26,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

 ### Fixed

--
+- Fixed bug in `PearsonCorrCoef` when it is updated on single samples at a time ([#2019](https://github.com/Lightning-AI/torchmetrics/pull/2019))
+
+
+- Fixed support for pixelwise MSE ([#2017](https://github.com/Lightning-AI/torchmetrics/pull/2017))


 ## [1.1.0] - 2023-08-22
diff --git a/docs/source/_static/images/logo.png b/docs/source/_static/images/logo.png
index 6849a229b32..5a0e2da0436 100644
Binary files a/docs/source/_static/images/logo.png and b/docs/source/_static/images/logo.png differ
diff --git a/docs/source/_static/images/logo.svg b/docs/source/_static/images/logo.svg
index eb6e01bda27..10976a31ea9 100644
--- a/docs/source/_static/images/logo.svg
+++ b/docs/source/_static/images/logo.svg
diff --git a/docs/source/clustering/mutual_info_score.rst b/docs/source/clustering/mutual_info_score.rst
new file mode 100644
index 00000000000..e5adf06eaa9
--- /dev/null
+++ b/docs/source/clustering/mutual_info_score.rst
@@ -0,0 +1,21 @@
+.. customcarditem::
+   :header: Mutual Information Score
+   :image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/clustering.svg
+   :tags: Clustering
+
+.. include:: ../links.rst
+
+########################
+Mutual Information Score
+########################
+
+Module Interface
+________________
+
+.. autoclass:: torchmetrics.clustering.MutualInfoScore
+    :exclude-members: update, compute
+
+Functional Interface
+____________________
+
+.. autofunction:: torchmetrics.functional.clustering.mutual_info_score
diff --git a/docs/source/index.rst b/docs/source/index.rst
index 9da8cf0a51a..af7b6ff798b 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -158,6 +158,14 @@ Or directly from conda

    classification/*

+.. toctree::
+   :maxdepth: 2
+   :name: clustering
+   :caption: Clustering
+   :glob:
+
+   clustering/*
+
 .. toctree::
    :maxdepth: 2
    :name: detection
diff --git a/docs/source/links.rst b/docs/source/links.rst
index 4ca837ccd64..7627490c661 100644
--- a/docs/source/links.rst
+++ b/docs/source/links.rst
@@ -150,4 +150,5 @@
 .. _CIOU: https://arxiv.org/abs/2005.03572
 .. _DIOU: https://arxiv.org/abs/1911.08287v1
 .. _GIOU: https://arxiv.org/abs/1902.09630
+.. _Mutual Information Score: https://en.wikipedia.org/wiki/Mutual_information
 .. _pycocotools: https://github.com/cocodataset/cocoapi/tree/master/PythonAPI/pycocotools
diff --git a/requirements.txt b/requirements.txt
index 27b5d0d3feb..536c920e6f5 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -4,5 +4,4 @@
 numpy >1.20.0
 torch >=1.8.1, <=2.0.1
 typing-extensions; python_version < '3.9'
-packaging  # hotfix for utils, can be dropped with lit-utils >=0.5
 lightning-utilities >=0.8.0, <0.10.0
diff --git a/src/torchmetrics/clustering/__init__.py b/src/torchmetrics/clustering/__init__.py
new file mode 100644
index 00000000000..baeb8c88d31
--- /dev/null
+++ b/src/torchmetrics/clustering/__init__.py
@@ -0,0 +1,18 @@
+# Copyright The Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from torchmetrics.clustering.mutual_info_score import MutualInfoScore
+
+__all__ = [
+    "MutualInfoScore",
+]
diff --git a/src/torchmetrics/clustering/mutual_info_score.py b/src/torchmetrics/clustering/mutual_info_score.py
new file mode 100644
index 00000000000..86118daf41c
--- /dev/null
+++ b/src/torchmetrics/clustering/mutual_info_score.py
@@ -0,0 +1,125 @@
+# Copyright The Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Any, List, Optional, Sequence, Union
+
+from torch import Tensor
+
+from torchmetrics.functional.clustering.mutual_info_score import mutual_info_score
+from torchmetrics.metric import Metric
+from torchmetrics.utilities.data import dim_zero_cat
+from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE
+from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE
+
+if not _MATPLOTLIB_AVAILABLE:
+    __doctest_skip__ = ["MutualInfoScore.plot"]
+
+
+class MutualInfoScore(Metric):
+    r"""Compute `Mutual Information Score`_.
+
+    .. math::
+        MI(U,V) = \sum_{i=1}^{|U|} \sum_{j=1}^{|V|} \frac{|U_i\cap V_j|}{N}
+        \log\frac{N|U_i\cap V_j|}{|U_i||V_j|}
+
+    Where :math:`U` is a tensor of target values, :math:`V` is a tensor of predictions,
+    :math:`|U_i|` is the number of samples in cluster :math:`U_i`, and
+    :math:`|V_i|` is the number of samples in cluster :math:`V_i`.
+
+    The metric is symmetric, therefore swapping :math:`U` and :math:`V` yields
+    the same mutual information score.
+
+    As input to ``forward`` and ``update`` the metric accepts the following input:
+
+    - ``preds`` (:class:`~torch.Tensor`): a single integer tensor with shape ``(N,)`` with predicted cluster labels
+    - ``target`` (:class:`~torch.Tensor`): a single integer tensor with shape ``(N,)`` with ground truth cluster labels
+
+    As output of ``forward`` and ``compute`` the metric returns the following output:
+
+    - ``mi_score`` (:class:`~torch.Tensor`): A tensor with the Mutual Information Score
+
+    Args:
+        kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
+
+    Example:
+        >>> import torch
+        >>> from torchmetrics.clustering import MutualInfoScore
+        >>> preds = torch.tensor([2, 1, 0, 1, 0])
+        >>> target = torch.tensor([0, 2, 1, 1, 0])
+        >>> mi_score = MutualInfoScore()
+        >>> mi_score(preds, target)
+        tensor(0.5004)
+
+    """
+
+    is_differentiable = True
+    higher_is_better = None
+    full_state_update: bool = True
+    plot_lower_bound: float = 0.0
+    preds: List[Tensor]
+    target: List[Tensor]
+    contingency: Tensor
+
+    def __init__(self, **kwargs: Any) -> None:
+        super().__init__(**kwargs)
+
+        self.add_state("preds", default=[], dist_reduce_fx="cat")
+        self.add_state("target", default=[], dist_reduce_fx="cat")
+
+    def update(self, preds: Tensor, target: Tensor) -> None:
+        """Update state with predictions and targets."""
+        self.preds.append(preds)
+        self.target.append(target)
+
+    def compute(self) -> Tensor:
+        """Compute mutual information over state."""
+        return mutual_info_score(dim_zero_cat(self.preds), dim_zero_cat(self.target))
+
+    def plot(self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_AX_TYPE] = None) -> _PLOT_OUT_TYPE:
+        """Plot a single or multiple values from the metric.
+
+        Args:
+            val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
+                If no value is provided, will automatically call `metric.compute` and plot that result.
+            ax: A matplotlib axis object. If provided will add plot to that axis
+
+        Returns:
+            Figure and Axes object
+
+        Raises:
+            ModuleNotFoundError:
+                If `matplotlib` is not installed
+
+        .. plot::
+            :scale: 75
+
+            >>> # Example plotting a single value
+            >>> import torch
+            >>> from torchmetrics.clustering import MutualInfoScore
+            >>> metric = MutualInfoScore()
+            >>> metric.update(torch.randint(0, 4, (10,)), torch.randint(0, 4, (10,)))
+            >>> fig_, ax_ = metric.plot(metric.compute())
+
+        .. plot::
+            :scale: 75
+
+            >>> # Example plotting multiple values
+            >>> import torch
+            >>> from torchmetrics.clustering import MutualInfoScore
+            >>> metric = MutualInfoScore()
+            >>> for _ in range(10):
+            ...     metric.update(torch.randint(0, 4, (10,)), torch.randint(0, 4, (10,)))
+            >>> fig_, ax_ = metric.plot(metric.compute())
+
+        """
+        return self._plot(val, ax)
diff --git a/src/torchmetrics/functional/clustering/__init__.py b/src/torchmetrics/functional/clustering/__init__.py
new file mode 100644
index 00000000000..c6f46126ca3
--- /dev/null
+++ b/src/torchmetrics/functional/clustering/__init__.py
@@ -0,0 +1,16 @@
+# Copyright The Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from torchmetrics.functional.clustering.mutual_info_score import mutual_info_score
+
+__all__ = ["mutual_info_score"]
diff --git a/src/torchmetrics/functional/clustering/mutual_info_score.py b/src/torchmetrics/functional/clustering/mutual_info_score.py
new file mode 100644
index 00000000000..f7c7cbfa587
--- /dev/null
+++ b/src/torchmetrics/functional/clustering/mutual_info_score.py
@@ -0,0 +1,79 @@
+# Copyright The Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import torch
+from torch import Tensor, tensor
+
+from torchmetrics.functional.clustering.utils import calculate_contingency_matrix, check_cluster_labels
+
+
+def _mutual_info_score_update(preds: Tensor, target: Tensor) -> Tensor:
+    """Update and return variables required to compute the mutual information score.
+
+    Args:
+        preds: predicted class labels
+        target: ground truth class labels
+
+    Returns:
+        contingency: contingency matrix
+
+    """
+    check_cluster_labels(preds, target)
+    return calculate_contingency_matrix(preds, target)
+
+
+def _mutual_info_score_compute(contingency: Tensor) -> Tensor:
+    """Compute the mutual information score based on the contingency matrix.
+
+    Args:
+        contingency: contingency matrix
+
+    Returns:
+        mutual_info: mutual information score
+
+    """
+    n = contingency.sum()
+    u = contingency.sum(dim=1)
+    v = contingency.sum(dim=0)
+
+    # Check if preds or target labels only have one cluster
+    if u.numel() == 1 or v.numel() == 1:
+        return tensor(0.0)
+
+    # Find indices of nonzero values in U and V
+    nzu, nzv = torch.nonzero(contingency, as_tuple=True)
+    contingency = contingency[nzu, nzv]
+
+    # Calculate MI using entries corresponding to nonzero contingency matrix entries
+    log_outer = torch.log(u[nzu]) + torch.log(v[nzv])
+    mutual_info = contingency / n * (torch.log(n) + torch.log(contingency) - log_outer)
+    return mutual_info.sum()
+
+
+def mutual_info_score(preds: Tensor, target: Tensor) -> Tensor:
+    """Compute mutual information between two clusterings.
+
+    Args:
+        preds: predicted classes
+        target: ground truth classes
+
+    Example:
+        >>> from torchmetrics.functional.clustering import mutual_info_score
+        >>> target = torch.tensor([0, 3, 2, 2, 1])
+        >>> preds = torch.tensor([1, 3, 2, 0, 1])
+        >>> mutual_info_score(preds, target)
+        tensor(1.0549)
+
+    """
+    contingency = _mutual_info_score_update(preds, target)
+    return _mutual_info_score_compute(contingency)
diff --git a/src/torchmetrics/functional/clustering/utils.py b/src/torchmetrics/functional/clustering/utils.py
new file mode 100644
index 00000000000..64dff0377ee
--- /dev/null
+++ b/src/torchmetrics/functional/clustering/utils.py
@@ -0,0 +1,101 @@
+# Copyright The Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Optional
+
+import torch
+from torch import Tensor
+
+from torchmetrics.utilities.checks import _check_same_shape
+
+
+def calculate_contingency_matrix(
+    preds: Tensor, target: Tensor, eps: Optional[float] = None, sparse: bool = False
+) -> Tensor:
+    """Calculate contingency matrix.
+
+    Args:
+        preds: predicted labels
+        target: ground truth labels
+        eps: value added to contingency matrix
+        sparse: If True, returns the contingency matrix as a sparse matrix; otherwise returns a dense matrix.
+            `eps` must be `None` if `sparse` is `True`.
+
+    Returns:
+        contingency: contingency matrix of shape (n_classes_target, n_classes_preds)
+
+    Example:
+        >>> import torch
+        >>> from torchmetrics.functional.clustering.utils import calculate_contingency_matrix
+        >>> preds = torch.tensor([2, 1, 0, 1, 0])
+        >>> target = torch.tensor([0, 2, 1, 1, 0])
+        >>> calculate_contingency_matrix(preds, target, eps=1e-16)
+        tensor([[1.0000e+00, 1.0000e-16, 1.0000e+00],
+                [1.0000e+00, 1.0000e+00, 1.0000e-16],
+                [1.0000e-16, 1.0000e+00, 1.0000e-16]])
+
+    """
+    if eps is not None and sparse is True:
+        raise ValueError("Cannot specify `eps` and return sparse tensor.")
+    if preds.ndim != 1 or target.ndim != 1:
+        raise ValueError(f"Expected 1d `preds` and `target` but got {preds.ndim} and {target.ndim}.")
+
+    preds_classes, preds_idx = torch.unique(preds, return_inverse=True)
+    target_classes, target_idx = torch.unique(target, return_inverse=True)
+
+    n_classes_preds = preds_classes.size(0)
+    n_classes_target = target_classes.size(0)
+
+    contingency = torch.sparse_coo_tensor(
+        torch.stack(
+            (
+                target_idx,
+                preds_idx,
+            )
+        ),
+        torch.ones(target_idx.shape[0], dtype=preds_idx.dtype, device=preds_idx.device),
+        (
+            n_classes_target,
+            n_classes_preds,
+        ),
+    )
+
+    if not sparse:
+        contingency = contingency.to_dense()
+        if eps:
+            contingency = contingency + eps
+
+    return contingency
+
+
+def check_cluster_labels(preds: Tensor, target: Tensor) -> None:
+    """Check shape of input tensors and if they are real, discrete tensors.
+
+    Args:
+        preds: predicted labels
+        target: ground truth labels
+
+    """
+    _check_same_shape(preds, target)
+    if preds.ndim != 1:
+        raise ValueError(f"Expected arguments to be 1d tensors but got {preds.ndim} and {target.ndim}")
+    if (
+        torch.is_floating_point(preds)
+        or torch.is_complex(preds)
+        or torch.is_floating_point(target)
+        or torch.is_complex(target)
+    ):
+        raise ValueError(
+            f"Expected real, discrete values but received {preds.dtype} for "
+            f"predictions and {target.dtype} for target labels instead."
+        )
diff --git a/src/torchmetrics/functional/regression/mse.py b/src/torchmetrics/functional/regression/mse.py
index fa20a3c96b6..c7d6d47dbfe 100644
--- a/src/torchmetrics/functional/regression/mse.py
+++ b/src/torchmetrics/functional/regression/mse.py
@@ -16,7 +16,6 @@
 import torch
 from torch import Tensor

-from torchmetrics.functional.regression.utils import _check_data_shape_to_num_outputs
 from torchmetrics.utilities.checks import _check_same_shape


@@ -32,7 +31,6 @@ def _mean_squared_error_update(preds: Tensor, target: Tensor, num_outputs: int)

     """
     _check_same_shape(preds, target)
-    _check_data_shape_to_num_outputs(preds, target, num_outputs, allow_1d_reshape=True)
     if num_outputs == 1:
         preds = preds.view(-1)
         target = target.view(-1)
diff --git a/src/torchmetrics/functional/regression/pearson.py b/src/torchmetrics/functional/regression/pearson.py
index 8c8a4896a38..3c3d4dfc2c9 100644
--- a/src/torchmetrics/functional/regression/pearson.py
+++ b/src/torchmetrics/functional/regression/pearson.py
@@ -52,9 +52,9 @@ def _pearson_corrcoef_update(
     # Data checking
     _check_same_shape(preds, target)
     _check_data_shape_to_num_outputs(preds, target, num_outputs)
-    cond = n_prior.mean() > 0
-
     n_obs = preds.shape[0]
+    cond = n_prior.mean() > 0 or n_obs == 1
+
     if cond:
         mx_new = (n_prior * mean_x + preds.sum(0)) / (n_prior + n_obs)
         my_new = (n_prior * mean_y + target.sum(0)) / (n_prior + n_obs)
@@ -67,7 +67,6 @@ def _pearson_corrcoef_update(
     if cond:
         var_x += ((preds - mx_new) * (preds - mean_x)).sum(0)
         var_y += ((target - my_new) * (target - mean_y)).sum(0)
-
     else:
         var_x += preds.var(0) * (n_obs - 1)
         var_y += target.var(0) * (n_obs - 1)
diff --git a/tests/unittests/clustering/__init__.py b/tests/unittests/clustering/__init__.py
new file mode 100644
index 00000000000..e69de29bb2d
diff --git a/tests/unittests/clustering/test_mutual_info_score.py b/tests/unittests/clustering/test_mutual_info_score.py
new file mode 100644
index 00000000000..c4e0e56f38e
--- /dev/null
+++ b/tests/unittests/clustering/test_mutual_info_score.py
@@ -0,0 +1,104 @@
+# Copyright The Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from collections import namedtuple
+
+import pytest
+import torch
+from sklearn.metrics import mutual_info_score as sklearn_mutual_info_score
+from torchmetrics.clustering.mutual_info_score import MutualInfoScore
+from torchmetrics.functional.clustering.mutual_info_score import mutual_info_score
+
+from unittests import BATCH_SIZE, NUM_BATCHES
+from unittests.helpers import seed_all
+from unittests.helpers.testers import MetricTester
+
+seed_all(42)
+
+Input = namedtuple("Input", ["preds", "target"])
+NUM_CLASSES = 10
+
+_single_target_inputs1 = Input(
+    preds=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE)),
+    target=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE)),
+)
+
+_single_target_inputs2 = Input(
+    preds=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE)),
+    target=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE)),
+)
+
+_float_inputs = Input(
+    preds=torch.rand((NUM_BATCHES, BATCH_SIZE)),
+    target=torch.rand((NUM_BATCHES, BATCH_SIZE)),
+)
+
+
+@pytest.mark.parametrize(
+    "preds, target",
+    [
+        (_single_target_inputs1.preds, _single_target_inputs1.target),
+        (_single_target_inputs2.preds, _single_target_inputs2.target),
+    ],
+)
+class TestMutualInfoScore(MetricTester):
+    """Test class for `MutualInfoScore` metric."""
+
+    atol = 1e-5
+
+    @pytest.mark.parametrize("ddp", [True, False])
+    def test_mutual_info_score(self, preds, target, ddp):
+        """Test class implementation of metric."""
+        self.run_class_metric_test(
+            ddp=ddp,
+            preds=preds,
+            target=target,
+            metric_class=MutualInfoScore,
+            reference_metric=sklearn_mutual_info_score,
+        )
+
+    def test_mutual_info_score_functional(self, preds, target):
+        """Test functional implementation of metric."""
+        self.run_functional_metric_test(
+            preds=preds,
+            target=target,
+            metric_functional=mutual_info_score,
+            reference_metric=sklearn_mutual_info_score,
+        )
+
+
+def test_mutual_info_score_functional_single_cluster():
+    """Check that for single cluster the metric returns 0."""
+    tensor_a = torch.randint(NUM_CLASSES, (BATCH_SIZE,))
+    tensor_b = torch.zeros(BATCH_SIZE, dtype=torch.int)
+    assert torch.allclose(mutual_info_score(tensor_a, tensor_b), torch.tensor(0.0))
+    assert torch.allclose(mutual_info_score(tensor_b, tensor_a), torch.tensor(0.0))
+
+
+def test_mutual_info_score_functional_raises_invalid_task():
+    """Check that metric rejects continuous-valued inputs."""
+    preds, target = _float_inputs
+    with pytest.raises(ValueError, match=r"Expected *"):
+        mutual_info_score(preds, target)
+
+
+@pytest.mark.parametrize(
+    ("preds", "target"),
+    [
+        (_single_target_inputs1.preds, _single_target_inputs1.target),
+    ],
+)
+def test_mutual_info_score_functional_is_symmetric(preds, target):
+    """Check that the metric functional is symmetric."""
+    for p, t in zip(preds, target):
+        assert torch.allclose(mutual_info_score(p, t), mutual_info_score(t, p))
diff --git a/tests/unittests/clustering/test_utils.py b/tests/unittests/clustering/test_utils.py
new file mode 100644
index 00000000000..95ee1a6a4a7
--- /dev/null
+++ b/tests/unittests/clustering/test_utils.py
@@ -0,0 +1,78 @@
+# Copyright The Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from collections import namedtuple
+
+import numpy as np
+import pytest
+import torch
+from sklearn.metrics.cluster import contingency_matrix as sklearn_contingency_matrix
+from torchmetrics.functional.clustering.utils import calculate_contingency_matrix
+
+from unittests import BATCH_SIZE
+from unittests.helpers import seed_all
+
+seed_all(42)
+
+Input = namedtuple("Input", ["preds", "target"])
+NUM_CLASSES = 10
+
+_sklearn_inputs = Input(
+    preds=torch.tensor([1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3]),
+    target=torch.tensor([1, 1, 1, 1, 2, 1, 2, 2, 2, 2, 3, 1, 3, 3, 3, 2, 2]),
+)
+
+_single_dim_inputs = Input(
+    preds=torch.randint(high=NUM_CLASSES, size=(BATCH_SIZE,)),
+    target=torch.randint(high=NUM_CLASSES, size=(BATCH_SIZE,)),
+)
+
+_multi_dim_inputs = Input(
+    preds=torch.randint(high=NUM_CLASSES, size=(BATCH_SIZE, 2)),
+    target=torch.randint(high=NUM_CLASSES, size=(BATCH_SIZE, 2)),
+)
+
+
+@pytest.mark.parametrize(
+    ("preds", "target"),
+    [(_sklearn_inputs.preds, _sklearn_inputs.target), (_single_dim_inputs.preds, _single_dim_inputs.target)],
+)
+class TestContingencyMatrix:
+    """Test calculation of dense and sparse contingency matrices."""
+
+    atol = 1e-8
+
+    @pytest.mark.parametrize("eps", [None, 1e-16])
+    def test_contingency_matrix_dense(self, preds, target, eps):
+        """Check that dense contingency matrices are calculated correctly."""
+        tm_c = calculate_contingency_matrix(preds, target, eps)
+        sklearn_c = sklearn_contingency_matrix(target, preds, eps=eps)
+        assert np.allclose(tm_c, sklearn_c, atol=self.atol)
+
+    def test_contingency_matrix_sparse(self, preds, target):
+        """Check that sparse contingency matrices are calculated correctly."""
+        tm_c = calculate_contingency_matrix(preds, target, sparse=True).to_dense().numpy()
+        sklearn_c = sklearn_contingency_matrix(target, preds, sparse=True).toarray()
+        assert np.allclose(tm_c, sklearn_c, atol=self.atol)
+
+
+def test_eps_and_sparse_error():
+    """Check that contingency matrix is not calculated if `eps` is nonzero and `sparse` is True."""
+    with pytest.raises(ValueError, match="Cannot specify*"):
+        calculate_contingency_matrix(_single_dim_inputs.preds, _single_dim_inputs.target, eps=1e-16, sparse=True)
+
+
+def test_multidimensional_contingency_error():
+    """Check that contingency matrix is not calculated for multidimensional input."""
+    with pytest.raises(ValueError, match="Expected 1d*"):
+        calculate_contingency_matrix(_multi_dim_inputs.preds, _multi_dim_inputs.target)
diff --git a/tests/unittests/regression/test_pearson.py b/tests/unittests/regression/test_pearson.py
index 043ca470d5c..90c7df76b92 100644
--- a/tests/unittests/regression/test_pearson.py
+++ b/tests/unittests/regression/test_pearson.py
@@ -149,3 +149,20 @@ def test_pearsons_warning_on_small_input(dtype, scale):
     target = scale * torch.randn(100, dtype=dtype)
     with pytest.warns(UserWarning, match="The variance of predictions or target is close to zero.*"):
         pearson_corrcoef(preds, target)
+
+
+def test_single_sample_update():
+    """See issue: https://github.com/Lightning-AI/torchmetrics/issues/2014."""
+    metric = PearsonCorrCoef()
+
+    # Works
+    metric(torch.tensor([3.0, -0.5, 2.0, 7.0]), torch.tensor([2.5, 0.0, 2.0, 8.0]))
+    res1 = metric.compute()
+    metric.reset()
+
+    metric(torch.tensor([3.0]), torch.tensor([2.5]))
+    metric(torch.tensor([-0.5]), torch.tensor([0.0]))
+    metric(torch.tensor([2.0]), torch.tensor([2.0]))
+    metric(torch.tensor([7.0]), torch.tensor([8.0]))
+    res2 = metric.compute()
+    assert torch.allclose(res1, res2)
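
Once this patch is applied, the new metric can be exercised end to end. A minimal usage sketch (label values and expected output taken from the doctests above; imports assume the package layout introduced in this diff):

    import torch

    from torchmetrics.clustering import MutualInfoScore
    from torchmetrics.functional.clustering import mutual_info_score

    preds = torch.tensor([2, 1, 0, 1, 0])   # predicted cluster labels
    target = torch.tensor([0, 2, 1, 1, 0])  # ground truth cluster labels

    # Module interface: accumulate label pairs across updates, then compute the score
    metric = MutualInfoScore()
    metric.update(preds, target)
    print(metric.compute())                  # tensor(0.5004)

    # Functional interface: one-shot computation; the score is symmetric in its arguments
    print(mutual_info_score(preds, target))  # tensor(0.5004)
    print(mutual_info_score(target, preds))  # tensor(0.5004)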