From b726b2e71dcc895becae3fb23f3d74559011919b Mon Sep 17 00:00:00 2001 From: Shion Date: Sat, 19 Aug 2023 20:58:21 +0900 Subject: [PATCH 01/25] working implementation --- .../functional/clustering/__init__.py | 5 + .../clustering/mutual_info_score.py | 119 ++++++++++++++++++ 2 files changed, 124 insertions(+) create mode 100644 src/torchmetrics/functional/clustering/__init__.py create mode 100644 src/torchmetrics/functional/clustering/mutual_info_score.py diff --git a/src/torchmetrics/functional/clustering/__init__.py b/src/torchmetrics/functional/clustering/__init__.py new file mode 100644 index 00000000000..322b4856620 --- /dev/null +++ b/src/torchmetrics/functional/clustering/__init__.py @@ -0,0 +1,5 @@ +from torchmetrics.functional.clustering.mutual_info_score import mutual_info_score + +__all__ = [ + "mutual_info_score" +] diff --git a/src/torchmetrics/functional/clustering/mutual_info_score.py b/src/torchmetrics/functional/clustering/mutual_info_score.py new file mode 100644 index 00000000000..c5ecc323552 --- /dev/null +++ b/src/torchmetrics/functional/clustering/mutual_info_score.py @@ -0,0 +1,119 @@ +# Copyright The Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch + +from typing import Optional, Tuple +from torch import Tensor, tensor + +from torchmetrics.utilities.checks import _check_same_shape + + +def _mutual_info_score_check(preds, target) -> bool: + """Check shape of input tensors.""" + # TODO: check if data are disjoint subsets + return _check_same_shape(preds, target) + + +def _calculate_contingency_matrix( + preds: Tensor, + target: Tensor, + eps: Optional[float] = 1e-16, + sparse: bool = False +) -> Tensor: + """Calculate contingency matrix. + + Args: + preds: predicted labels + target: ground truth labels + sparse: If True, returns contingency matrix as a sparse matrix. + + Returns: + contingency: contingency matrix of shape (n_classes_target, n_classes_preds) + """ + if eps is not None and sparse is True: + raise ValueError('Cannot specify `eps` and return sparse tensor.') + + preds_classes, preds_idx = torch.unique(preds, return_inverse=True) + target_classes, target_idx = torch.unique(target, return_inverse=True) + + n_classes_preds = preds_classes.size(0) + n_classes_target = target_classes.size(0) + + contingency = torch.sparse_coo_tensor( + torch.stack((target_idx, preds_idx)), + torch.ones(target_idx.size(0)), + (n_classes_target, n_classes_preds) + ) + + if not sparse: + contingency = contingency.to_dense() + if eps: + contingency = contingency + eps + + return contingency + + +def _mutual_info_score_update(preds, target) -> Tuple[Tensor, Tensor, Tensor]: + """Update and return variables required to compute the mutual information score. + + Args: + preds: predicted class labels + target: ground truth class labels + + Returns: + contingency: contingency matrix + """ + _mutual_info_score_check(preds, target) + return _calculate_contingency_matrix(preds, target) + + +def _mutual_info_score_compute(contingency: Tensor) -> Tensor: + """Compute the mutual information score based on the contingency matrix. + + Args: + contingency: contingency matrix + + Returns: + mutual_info: mutual information score + """ + N = contingency.sum() + U = contingency.sum(dim=1) + V = contingency.sum(dim=0) + + # Check if preds or target labels only have one cluster + if U.size() == 1 or V.size() == 1: + return tensor(0.0) + + log_outer = torch.log(U).reshape(-1, 1) + torch.log(V) + mutual_info = contingency / N * (torch.log(N) + torch.log(contingency) - log_outer) + return mutual_info.sum() + + +def mutual_info_score(preds: Tensor, target: Tensor) -> Tensor: + """Compute mutual information between two clusterings. + + Args: + preds: predicted classes + target: ground truth classes + + Example: + >>> from torchmetrics.functional.clustering import mutual_info_score + >>> target = torch.tensor([0, 3, 2, 2, 1]) + >>> preds = torch.tensor([1, 3, 2, 0, 1]) + >>> mutual_info_score(preds, target) + tensor([1.05492]) + """ + _mutual_info_score_check(preds, target) + contingency = _mutual_info_score_update(preds, target) + return _mutual_info_score_compute(contingency) From a065ef1338e919abf26d6b46146b63a5c8f2c48f Mon Sep 17 00:00:00 2001 From: Shion Date: Sat, 19 Aug 2023 22:33:28 +0900 Subject: [PATCH 02/25] passing functional and basic error tests --- .../clustering/mutual_info_score.py | 22 +++-- tests/unittests/clustering/__init__.py | 0 .../clustering/test_mutual_info_score.py | 88 +++++++++++++++++++ 3 files changed, 104 insertions(+), 6 deletions(-) create mode 100644 tests/unittests/clustering/__init__.py create mode 100644 tests/unittests/clustering/test_mutual_info_score.py diff --git a/src/torchmetrics/functional/clustering/mutual_info_score.py b/src/torchmetrics/functional/clustering/mutual_info_score.py index c5ecc323552..0402eb3382c 100644 --- a/src/torchmetrics/functional/clustering/mutual_info_score.py +++ b/src/torchmetrics/functional/clustering/mutual_info_score.py @@ -19,10 +19,14 @@ from torchmetrics.utilities.checks import _check_same_shape -def _mutual_info_score_check(preds, target) -> bool: +def check_cluster_labels(preds: Tensor, target: Tensor) -> None: """Check shape of input tensors.""" - # TODO: check if data are disjoint subsets - return _check_same_shape(preds, target) + _check_same_shape(preds, target) + if torch.is_floating_point(preds) or torch.is_floating_point(target): + raise ValueError( + f"Expected discrete values but received {preds.dtype} for" + f"predictions and {target.dtype} for target labels instead." + ) def _calculate_contingency_matrix( @@ -40,6 +44,7 @@ def _calculate_contingency_matrix( Returns: contingency: contingency matrix of shape (n_classes_target, n_classes_preds) + """ if eps is not None and sparse is True: raise ValueError('Cannot specify `eps` and return sparse tensor.') @@ -64,7 +69,10 @@ def _calculate_contingency_matrix( return contingency -def _mutual_info_score_update(preds, target) -> Tuple[Tensor, Tensor, Tensor]: +def _mutual_info_score_update( + preds: Tensor, + target: Tensor +) -> Tuple[Tensor, Tensor, Tensor]: """Update and return variables required to compute the mutual information score. Args: @@ -73,8 +81,9 @@ def _mutual_info_score_update(preds, target) -> Tuple[Tensor, Tensor, Tensor]: Returns: contingency: contingency matrix + """ - _mutual_info_score_check(preds, target) + check_cluster_labels(preds, target) return _calculate_contingency_matrix(preds, target) @@ -86,6 +95,7 @@ def _mutual_info_score_compute(contingency: Tensor) -> Tensor: Returns: mutual_info: mutual information score + """ N = contingency.sum() U = contingency.sum(dim=1) @@ -113,7 +123,7 @@ def mutual_info_score(preds: Tensor, target: Tensor) -> Tensor: >>> preds = torch.tensor([1, 3, 2, 0, 1]) >>> mutual_info_score(preds, target) tensor([1.05492]) + """ - _mutual_info_score_check(preds, target) contingency = _mutual_info_score_update(preds, target) return _mutual_info_score_compute(contingency) diff --git a/tests/unittests/clustering/__init__.py b/tests/unittests/clustering/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/tests/unittests/clustering/test_mutual_info_score.py b/tests/unittests/clustering/test_mutual_info_score.py new file mode 100644 index 00000000000..a44fead5751 --- /dev/null +++ b/tests/unittests/clustering/test_mutual_info_score.py @@ -0,0 +1,88 @@ +# Copyright The Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from collections import namedtuple +from functools import partial + +import pytest +import torch +from sklearn.metrics import mutual_info_score as scipy_mutual_info_score +from torchmetrics.functional.clustering.mutual_info_score import mutual_info_score +from torchmetrics.clustering.mutual_info_score import MutualInfoScore + +from unittests import BATCH_SIZE, NUM_BATCHES +from unittests.helpers import seed_all +from unittests.helpers.testers import MetricTester + +seed_all(42) + +Input = namedtuple("Input", ["preds", "target"]) +NUM_CLASSES = 10 + +_single_target_inputs1 = Input( + preds=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE)), + target=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE)), +) + +_single_target_inputs2 = Input( + preds=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE)), + target=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE)), +) + +_float_inputs = Input( + preds=torch.rand((NUM_BATCHES, BATCH_SIZE)), + target=torch.rand((NUM_BATCHES, BATCH_SIZE)), +) + + +@pytest.mark.parametrize( + "preds, target", + [ + (_single_target_inputs1.preds, _single_target_inputs1.target), + (_single_target_inputs2.preds, _single_target_inputs2.target), + ], +) +class TestMutualInfoScore(MetricTester): + """Test class for `MutualInfoScore` metric.""" + + atol = 1e-3 + + @pytest.mark.parametrize("compute_on_cpu", [True, False]) + @pytest.mark.parametrize("ddp", [True, False]) + def test_mutual_info_score(self, preds, target, compute_on_cpu, ddp): + """Test class implementation of metric.""" + metric_args = {"num_classes": NUM_CLASSES} + self.run_class_metric_test( + ddp=ddp, + preds=preds, + target=target, + metric_class=MutualInfoScore, + reference_metric=scipy_mutual_info_score, + metric_args=metric_args + ) + + def test_mutual_info_score_functional(self, preds, target): + """Test functional implementation of metric.""" + self.run_functional_metric_test( + preds=preds, + target=target, + metric_functional=mutual_info_score, + reference_metric=scipy_mutual_info_score + ) + + +def test_mutual_info_score_functional_raises_invalid_task(): + """Check that metric rejects continuous-valued inputs.""" + preds, target = _float_inputs + with pytest.raises(ValueError, match=r"Expected discrete *"): + mutual_info_score(preds, target) From f355a3be77ea9b1bff92d622fc8ea87d8a1c7b06 Mon Sep 17 00:00:00 2001 From: Shion Date: Sat, 19 Aug 2023 20:58:21 +0900 Subject: [PATCH 03/25] working implementation --- .../functional/clustering/__init__.py | 5 + .../clustering/mutual_info_score.py | 119 ++++++++++++++++++ 2 files changed, 124 insertions(+) create mode 100644 src/torchmetrics/functional/clustering/__init__.py create mode 100644 src/torchmetrics/functional/clustering/mutual_info_score.py diff --git a/src/torchmetrics/functional/clustering/__init__.py b/src/torchmetrics/functional/clustering/__init__.py new file mode 100644 index 00000000000..322b4856620 --- /dev/null +++ b/src/torchmetrics/functional/clustering/__init__.py @@ -0,0 +1,5 @@ +from torchmetrics.functional.clustering.mutual_info_score import mutual_info_score + +__all__ = [ + "mutual_info_score" +] diff --git a/src/torchmetrics/functional/clustering/mutual_info_score.py b/src/torchmetrics/functional/clustering/mutual_info_score.py new file mode 100644 index 00000000000..c5ecc323552 --- /dev/null +++ b/src/torchmetrics/functional/clustering/mutual_info_score.py @@ -0,0 +1,119 @@ +# Copyright The Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch + +from typing import Optional, Tuple +from torch import Tensor, tensor + +from torchmetrics.utilities.checks import _check_same_shape + + +def _mutual_info_score_check(preds, target) -> bool: + """Check shape of input tensors.""" + # TODO: check if data are disjoint subsets + return _check_same_shape(preds, target) + + +def _calculate_contingency_matrix( + preds: Tensor, + target: Tensor, + eps: Optional[float] = 1e-16, + sparse: bool = False +) -> Tensor: + """Calculate contingency matrix. + + Args: + preds: predicted labels + target: ground truth labels + sparse: If True, returns contingency matrix as a sparse matrix. + + Returns: + contingency: contingency matrix of shape (n_classes_target, n_classes_preds) + """ + if eps is not None and sparse is True: + raise ValueError('Cannot specify `eps` and return sparse tensor.') + + preds_classes, preds_idx = torch.unique(preds, return_inverse=True) + target_classes, target_idx = torch.unique(target, return_inverse=True) + + n_classes_preds = preds_classes.size(0) + n_classes_target = target_classes.size(0) + + contingency = torch.sparse_coo_tensor( + torch.stack((target_idx, preds_idx)), + torch.ones(target_idx.size(0)), + (n_classes_target, n_classes_preds) + ) + + if not sparse: + contingency = contingency.to_dense() + if eps: + contingency = contingency + eps + + return contingency + + +def _mutual_info_score_update(preds, target) -> Tuple[Tensor, Tensor, Tensor]: + """Update and return variables required to compute the mutual information score. + + Args: + preds: predicted class labels + target: ground truth class labels + + Returns: + contingency: contingency matrix + """ + _mutual_info_score_check(preds, target) + return _calculate_contingency_matrix(preds, target) + + +def _mutual_info_score_compute(contingency: Tensor) -> Tensor: + """Compute the mutual information score based on the contingency matrix. + + Args: + contingency: contingency matrix + + Returns: + mutual_info: mutual information score + """ + N = contingency.sum() + U = contingency.sum(dim=1) + V = contingency.sum(dim=0) + + # Check if preds or target labels only have one cluster + if U.size() == 1 or V.size() == 1: + return tensor(0.0) + + log_outer = torch.log(U).reshape(-1, 1) + torch.log(V) + mutual_info = contingency / N * (torch.log(N) + torch.log(contingency) - log_outer) + return mutual_info.sum() + + +def mutual_info_score(preds: Tensor, target: Tensor) -> Tensor: + """Compute mutual information between two clusterings. + + Args: + preds: predicted classes + target: ground truth classes + + Example: + >>> from torchmetrics.functional.clustering import mutual_info_score + >>> target = torch.tensor([0, 3, 2, 2, 1]) + >>> preds = torch.tensor([1, 3, 2, 0, 1]) + >>> mutual_info_score(preds, target) + tensor([1.05492]) + """ + _mutual_info_score_check(preds, target) + contingency = _mutual_info_score_update(preds, target) + return _mutual_info_score_compute(contingency) From e6862da1b1b793e501440354597aadaf26735d39 Mon Sep 17 00:00:00 2001 From: Shion Date: Sat, 19 Aug 2023 22:33:28 +0900 Subject: [PATCH 04/25] passing functional and basic error tests --- .../clustering/mutual_info_score.py | 22 +++-- tests/unittests/clustering/__init__.py | 0 .../clustering/test_mutual_info_score.py | 88 +++++++++++++++++++ 3 files changed, 104 insertions(+), 6 deletions(-) create mode 100644 tests/unittests/clustering/__init__.py create mode 100644 tests/unittests/clustering/test_mutual_info_score.py diff --git a/src/torchmetrics/functional/clustering/mutual_info_score.py b/src/torchmetrics/functional/clustering/mutual_info_score.py index c5ecc323552..0402eb3382c 100644 --- a/src/torchmetrics/functional/clustering/mutual_info_score.py +++ b/src/torchmetrics/functional/clustering/mutual_info_score.py @@ -19,10 +19,14 @@ from torchmetrics.utilities.checks import _check_same_shape -def _mutual_info_score_check(preds, target) -> bool: +def check_cluster_labels(preds: Tensor, target: Tensor) -> None: """Check shape of input tensors.""" - # TODO: check if data are disjoint subsets - return _check_same_shape(preds, target) + _check_same_shape(preds, target) + if torch.is_floating_point(preds) or torch.is_floating_point(target): + raise ValueError( + f"Expected discrete values but received {preds.dtype} for" + f"predictions and {target.dtype} for target labels instead." + ) def _calculate_contingency_matrix( @@ -40,6 +44,7 @@ def _calculate_contingency_matrix( Returns: contingency: contingency matrix of shape (n_classes_target, n_classes_preds) + """ if eps is not None and sparse is True: raise ValueError('Cannot specify `eps` and return sparse tensor.') @@ -64,7 +69,10 @@ def _calculate_contingency_matrix( return contingency -def _mutual_info_score_update(preds, target) -> Tuple[Tensor, Tensor, Tensor]: +def _mutual_info_score_update( + preds: Tensor, + target: Tensor +) -> Tuple[Tensor, Tensor, Tensor]: """Update and return variables required to compute the mutual information score. Args: @@ -73,8 +81,9 @@ def _mutual_info_score_update(preds, target) -> Tuple[Tensor, Tensor, Tensor]: Returns: contingency: contingency matrix + """ - _mutual_info_score_check(preds, target) + check_cluster_labels(preds, target) return _calculate_contingency_matrix(preds, target) @@ -86,6 +95,7 @@ def _mutual_info_score_compute(contingency: Tensor) -> Tensor: Returns: mutual_info: mutual information score + """ N = contingency.sum() U = contingency.sum(dim=1) @@ -113,7 +123,7 @@ def mutual_info_score(preds: Tensor, target: Tensor) -> Tensor: >>> preds = torch.tensor([1, 3, 2, 0, 1]) >>> mutual_info_score(preds, target) tensor([1.05492]) + """ - _mutual_info_score_check(preds, target) contingency = _mutual_info_score_update(preds, target) return _mutual_info_score_compute(contingency) diff --git a/tests/unittests/clustering/__init__.py b/tests/unittests/clustering/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/tests/unittests/clustering/test_mutual_info_score.py b/tests/unittests/clustering/test_mutual_info_score.py new file mode 100644 index 00000000000..a44fead5751 --- /dev/null +++ b/tests/unittests/clustering/test_mutual_info_score.py @@ -0,0 +1,88 @@ +# Copyright The Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from collections import namedtuple +from functools import partial + +import pytest +import torch +from sklearn.metrics import mutual_info_score as scipy_mutual_info_score +from torchmetrics.functional.clustering.mutual_info_score import mutual_info_score +from torchmetrics.clustering.mutual_info_score import MutualInfoScore + +from unittests import BATCH_SIZE, NUM_BATCHES +from unittests.helpers import seed_all +from unittests.helpers.testers import MetricTester + +seed_all(42) + +Input = namedtuple("Input", ["preds", "target"]) +NUM_CLASSES = 10 + +_single_target_inputs1 = Input( + preds=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE)), + target=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE)), +) + +_single_target_inputs2 = Input( + preds=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE)), + target=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE)), +) + +_float_inputs = Input( + preds=torch.rand((NUM_BATCHES, BATCH_SIZE)), + target=torch.rand((NUM_BATCHES, BATCH_SIZE)), +) + + +@pytest.mark.parametrize( + "preds, target", + [ + (_single_target_inputs1.preds, _single_target_inputs1.target), + (_single_target_inputs2.preds, _single_target_inputs2.target), + ], +) +class TestMutualInfoScore(MetricTester): + """Test class for `MutualInfoScore` metric.""" + + atol = 1e-3 + + @pytest.mark.parametrize("compute_on_cpu", [True, False]) + @pytest.mark.parametrize("ddp", [True, False]) + def test_mutual_info_score(self, preds, target, compute_on_cpu, ddp): + """Test class implementation of metric.""" + metric_args = {"num_classes": NUM_CLASSES} + self.run_class_metric_test( + ddp=ddp, + preds=preds, + target=target, + metric_class=MutualInfoScore, + reference_metric=scipy_mutual_info_score, + metric_args=metric_args + ) + + def test_mutual_info_score_functional(self, preds, target): + """Test functional implementation of metric.""" + self.run_functional_metric_test( + preds=preds, + target=target, + metric_functional=mutual_info_score, + reference_metric=scipy_mutual_info_score + ) + + +def test_mutual_info_score_functional_raises_invalid_task(): + """Check that metric rejects continuous-valued inputs.""" + preds, target = _float_inputs + with pytest.raises(ValueError, match=r"Expected discrete *"): + mutual_info_score(preds, target) From fbfae57e098aab61f8efef95fcc5570854f842d8 Mon Sep 17 00:00:00 2001 From: Shion Date: Mon, 21 Aug 2023 23:57:55 +0900 Subject: [PATCH 05/25] clean up naming and imports --- .../functional/clustering/mutual_info_score.py | 12 +++++++----- .../unittests/clustering/test_mutual_info_score.py | 14 +++++++------- 2 files changed, 14 insertions(+), 12 deletions(-) diff --git a/src/torchmetrics/functional/clustering/mutual_info_score.py b/src/torchmetrics/functional/clustering/mutual_info_score.py index 0402eb3382c..badfd3aaa02 100644 --- a/src/torchmetrics/functional/clustering/mutual_info_score.py +++ b/src/torchmetrics/functional/clustering/mutual_info_score.py @@ -19,12 +19,13 @@ from torchmetrics.utilities.checks import _check_same_shape -def check_cluster_labels(preds: Tensor, target: Tensor) -> None: +def _check_cluster_labels(preds: Tensor, target: Tensor) -> None: """Check shape of input tensors.""" _check_same_shape(preds, target) - if torch.is_floating_point(preds) or torch.is_floating_point(target): + if torch.is_floating_point(preds) or torch.is_complex(preds) or \ + torch.is_floating_point(target) or torch.is_complex(target): raise ValueError( - f"Expected discrete values but received {preds.dtype} for" + f"Expected real, discrete values but received {preds.dtype} for" f"predictions and {target.dtype} for target labels instead." ) @@ -71,7 +72,8 @@ def _calculate_contingency_matrix( def _mutual_info_score_update( preds: Tensor, - target: Tensor + target: Tensor, + # num_classes: int ) -> Tuple[Tensor, Tensor, Tensor]: """Update and return variables required to compute the mutual information score. @@ -83,7 +85,7 @@ def _mutual_info_score_update( contingency: contingency matrix """ - check_cluster_labels(preds, target) + _check_cluster_labels(preds, target) return _calculate_contingency_matrix(preds, target) diff --git a/tests/unittests/clustering/test_mutual_info_score.py b/tests/unittests/clustering/test_mutual_info_score.py index a44fead5751..4ec74ad28c2 100644 --- a/tests/unittests/clustering/test_mutual_info_score.py +++ b/tests/unittests/clustering/test_mutual_info_score.py @@ -16,7 +16,7 @@ import pytest import torch -from sklearn.metrics import mutual_info_score as scipy_mutual_info_score +from sklearn.metrics import mutual_info_score as sklearn_mutual_info_score from torchmetrics.functional.clustering.mutual_info_score import mutual_info_score from torchmetrics.clustering.mutual_info_score import MutualInfoScore @@ -55,20 +55,20 @@ class TestMutualInfoScore(MetricTester): """Test class for `MutualInfoScore` metric.""" - atol = 1e-3 + atol = 1e-5 @pytest.mark.parametrize("compute_on_cpu", [True, False]) @pytest.mark.parametrize("ddp", [True, False]) def test_mutual_info_score(self, preds, target, compute_on_cpu, ddp): """Test class implementation of metric.""" - metric_args = {"num_classes": NUM_CLASSES} + # metric_args = {"num_classes": NUM_CLASSES} self.run_class_metric_test( ddp=ddp, preds=preds, target=target, metric_class=MutualInfoScore, - reference_metric=scipy_mutual_info_score, - metric_args=metric_args + reference_metric=sklearn_mutual_info_score, + # metric_args=metric_args ) def test_mutual_info_score_functional(self, preds, target): @@ -77,12 +77,12 @@ def test_mutual_info_score_functional(self, preds, target): preds=preds, target=target, metric_functional=mutual_info_score, - reference_metric=scipy_mutual_info_score + reference_metric=sklearn_mutual_info_score, ) def test_mutual_info_score_functional_raises_invalid_task(): """Check that metric rejects continuous-valued inputs.""" preds, target = _float_inputs - with pytest.raises(ValueError, match=r"Expected discrete *"): + with pytest.raises(ValueError, match=r"Expected *"): mutual_info_score(preds, target) From f72183d6471cb1890ee87e8643ef15e68c72e643 Mon Sep 17 00:00:00 2001 From: Shion Date: Mon, 21 Aug 2023 23:58:25 +0900 Subject: [PATCH 06/25] push metric class (broken but to allow review) --- src/torchmetrics/clustering/__init__.py | 19 +++ .../clustering/mutual_info_score.py | 130 ++++++++++++++++++ 2 files changed, 149 insertions(+) create mode 100644 src/torchmetrics/clustering/__init__.py create mode 100644 src/torchmetrics/clustering/mutual_info_score.py diff --git a/src/torchmetrics/clustering/__init__.py b/src/torchmetrics/clustering/__init__.py new file mode 100644 index 00000000000..e86cc406cb1 --- /dev/null +++ b/src/torchmetrics/clustering/__init__.py @@ -0,0 +1,19 @@ +# Copyright The Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from torchmetrics.clustering.mutual_info_score import MutualInfoScore + + +__all__ = [ + "MutualInfoScore", +] diff --git a/src/torchmetrics/clustering/mutual_info_score.py b/src/torchmetrics/clustering/mutual_info_score.py new file mode 100644 index 00000000000..a1c781e2cae --- /dev/null +++ b/src/torchmetrics/clustering/mutual_info_score.py @@ -0,0 +1,130 @@ +# Copyright The Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch + +from typing import Any, Optional, List, Sequence, Union +from torch import Tensor + +from torchmetrics.functional.clustering.mutual_info_score import ( + _mutual_info_score_compute, + _mutual_info_score_update +) +from torchmetrics.metric import Metric +from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE +from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE + +if not _MATPLOTLIB_AVAILABLE: + __doctest_skip__ = ["MutualInfoScore.plot"] + + +class MutualInfoScore(Metric): + r"""Compute `Mutual Information Score`_. + + .. math:: + MI(U,V) = \sum_{i=1}^{\abs{U}} \sum_{j=1}^{\abs{V}} \frac{\abs{U_i\cap V_j}}{N} \log\frac{N\abs{U_i\cap V_j}}{\abs{U_i}\abs{V_j}} + + Where :math:`U` is a tensor of target values, :math:`V` is a tensor of predictions, + :math:`\abs{U_i}` is the number of samples in cluster :math:`U_i`, and + :math:`\abs{V_i}` is the number of samples in cluster :math:`V_i`. + + The metric is symmetric, therefore swapping :math:`U` and :math:`V` yields + the same mutual information score. + + Args: + kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. + + As input to ``forward`` and ``update`` the metric accepts the following input: + + - ``preds`` (:class:`~torch.Tensor`): either single output float tensor with shape ``(N,)`` + - ``target`` (:class:`~torch.Tensor`): either single output tensor with shape ``(N,)`` + + As output of ``forward`` and ``compute`` the metric returns the following output: + + - ``mi_score`` (:class:`~torch.Tensor`): A tensor with the Mutual Information Score + + Example: + >>> from torchmetrics.clustering import MutualInfoScore + >>> target = torch.tensor([]) + >>> preds = torch.tensor([]) + >>> mi_score = MutualInfoScore() + >>> mi_score(preds, target) + tensor() + """ + + is_differentiable = True + higher_is_better = None + full_state_update: bool = True + plot_lower_bound: float = 0.0 + plot_upper_bound: float = 1.0 # theoretical upper bound is +inf + preds: List[Tensor] + target: List[Tensor] + contingency: Tensor + + def __init__(self, **kwargs: Any) -> None: + super().__init__(**kwargs) + # self.num_classes = num_classes + # + # self.add_state("contingency", default=torch.zeros(self.num_classes), dist_reduce_fx=None) + + def update(self, preds: Tensor, target: Tensor) -> None: + """Update state with predictions and targets.""" + self.contingency = _mutual_info_score_update(preds, target) + + def compute(self) -> Tensor: + """Compute mutual information over state.""" + return _mutual_info_score_compute(self.contingency) + + def plot( + self, + val: Union[Tensor, Sequence[Tensor], None] = None, + ax: Optional[_AX_TYPE] = None + ) -> _PLOT_OUT_TYPE: + """Plot a single or multiple values from the metric. + + Args: + val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results. + If no value is provided, will automatically call `metric.compute` and plot that result. + ax: An matplotlib axis object. If provided will add plot to that axis + + Returns: + Figure and Axes object + + Raises: + ModuleNotFoundError: + If `matplotlib` is not installed + + .. plot:: + :scale: 75 + + >>> # Example plotting a single value + >>> import torch + >>> from torchmetrics.clustering import MutualInfoScore + >>> metric = MutualInfoScore(num_classes=5) + >>> metric.update(torch.randint(0, 4, (100,)), torch.randint(0, 4, (100,))) + >>> fig_, ax_ = metric.plot() + + .. plot:: + :scale: 75 + + >>> # Example plotting multiple values + >>> import torch + >>> from torchmetrics.clustering import MutualInfoScore + >>> metric = MutualInfoScore(num_classes=5) + >>> values = [ ] + >>> for _ in range(10): + ... values.append(metric(torch.randint(0, 4, (100,)), torch.randint(0, 4, (100,)))) + >>> fig_, ax_ = metric.plot(values) + + """ + return self._plot(val, ax) From 7fe14e0fb6a89435a0e2be6792917a98c26d2058 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 21 Aug 2023 15:01:58 +0000 Subject: [PATCH 07/25] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/torchmetrics/clustering/__init__.py | 1 - .../clustering/mutual_info_score.py | 16 ++++--------- .../functional/clustering/__init__.py | 4 +--- .../clustering/mutual_info_score.py | 23 +++++++++---------- .../clustering/test_mutual_info_score.py | 2 +- 5 files changed, 18 insertions(+), 28 deletions(-) diff --git a/src/torchmetrics/clustering/__init__.py b/src/torchmetrics/clustering/__init__.py index e86cc406cb1..baeb8c88d31 100644 --- a/src/torchmetrics/clustering/__init__.py +++ b/src/torchmetrics/clustering/__init__.py @@ -13,7 +13,6 @@ # limitations under the License. from torchmetrics.clustering.mutual_info_score import MutualInfoScore - __all__ = [ "MutualInfoScore", ] diff --git a/src/torchmetrics/clustering/mutual_info_score.py b/src/torchmetrics/clustering/mutual_info_score.py index a1c781e2cae..7605d12b08b 100644 --- a/src/torchmetrics/clustering/mutual_info_score.py +++ b/src/torchmetrics/clustering/mutual_info_score.py @@ -11,15 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import torch +from typing import Any, List, Optional, Sequence, Union -from typing import Any, Optional, List, Sequence, Union +import torch from torch import Tensor -from torchmetrics.functional.clustering.mutual_info_score import ( - _mutual_info_score_compute, - _mutual_info_score_update -) +from torchmetrics.functional.clustering.mutual_info_score import _mutual_info_score_compute, _mutual_info_score_update from torchmetrics.metric import Metric from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE @@ -60,6 +57,7 @@ class MutualInfoScore(Metric): >>> mi_score = MutualInfoScore() >>> mi_score(preds, target) tensor() + """ is_differentiable = True @@ -85,11 +83,7 @@ def compute(self) -> Tensor: """Compute mutual information over state.""" return _mutual_info_score_compute(self.contingency) - def plot( - self, - val: Union[Tensor, Sequence[Tensor], None] = None, - ax: Optional[_AX_TYPE] = None - ) -> _PLOT_OUT_TYPE: + def plot(self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_AX_TYPE] = None) -> _PLOT_OUT_TYPE: """Plot a single or multiple values from the metric. Args: diff --git a/src/torchmetrics/functional/clustering/__init__.py b/src/torchmetrics/functional/clustering/__init__.py index 322b4856620..576acda5108 100644 --- a/src/torchmetrics/functional/clustering/__init__.py +++ b/src/torchmetrics/functional/clustering/__init__.py @@ -1,5 +1,3 @@ from torchmetrics.functional.clustering.mutual_info_score import mutual_info_score -__all__ = [ - "mutual_info_score" -] +__all__ = ["mutual_info_score"] diff --git a/src/torchmetrics/functional/clustering/mutual_info_score.py b/src/torchmetrics/functional/clustering/mutual_info_score.py index badfd3aaa02..180aa4496a3 100644 --- a/src/torchmetrics/functional/clustering/mutual_info_score.py +++ b/src/torchmetrics/functional/clustering/mutual_info_score.py @@ -11,9 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import torch - from typing import Optional, Tuple + +import torch from torch import Tensor, tensor from torchmetrics.utilities.checks import _check_same_shape @@ -22,8 +22,12 @@ def _check_cluster_labels(preds: Tensor, target: Tensor) -> None: """Check shape of input tensors.""" _check_same_shape(preds, target) - if torch.is_floating_point(preds) or torch.is_complex(preds) or \ - torch.is_floating_point(target) or torch.is_complex(target): + if ( + torch.is_floating_point(preds) + or torch.is_complex(preds) + or torch.is_floating_point(target) + or torch.is_complex(target) + ): raise ValueError( f"Expected real, discrete values but received {preds.dtype} for" f"predictions and {target.dtype} for target labels instead." @@ -31,10 +35,7 @@ def _check_cluster_labels(preds: Tensor, target: Tensor) -> None: def _calculate_contingency_matrix( - preds: Tensor, - target: Tensor, - eps: Optional[float] = 1e-16, - sparse: bool = False + preds: Tensor, target: Tensor, eps: Optional[float] = 1e-16, sparse: bool = False ) -> Tensor: """Calculate contingency matrix. @@ -48,7 +49,7 @@ def _calculate_contingency_matrix( """ if eps is not None and sparse is True: - raise ValueError('Cannot specify `eps` and return sparse tensor.') + raise ValueError("Cannot specify `eps` and return sparse tensor.") preds_classes, preds_idx = torch.unique(preds, return_inverse=True) target_classes, target_idx = torch.unique(target, return_inverse=True) @@ -57,9 +58,7 @@ def _calculate_contingency_matrix( n_classes_target = target_classes.size(0) contingency = torch.sparse_coo_tensor( - torch.stack((target_idx, preds_idx)), - torch.ones(target_idx.size(0)), - (n_classes_target, n_classes_preds) + torch.stack((target_idx, preds_idx)), torch.ones(target_idx.size(0)), (n_classes_target, n_classes_preds) ) if not sparse: diff --git a/tests/unittests/clustering/test_mutual_info_score.py b/tests/unittests/clustering/test_mutual_info_score.py index 4ec74ad28c2..d594d0f140a 100644 --- a/tests/unittests/clustering/test_mutual_info_score.py +++ b/tests/unittests/clustering/test_mutual_info_score.py @@ -17,8 +17,8 @@ import pytest import torch from sklearn.metrics import mutual_info_score as sklearn_mutual_info_score -from torchmetrics.functional.clustering.mutual_info_score import mutual_info_score from torchmetrics.clustering.mutual_info_score import MutualInfoScore +from torchmetrics.functional.clustering.mutual_info_score import mutual_info_score from unittests import BATCH_SIZE, NUM_BATCHES from unittests.helpers import seed_all From 808b2785ad651b0fe21b90801499545792597f2b Mon Sep 17 00:00:00 2001 From: Shion Date: Tue, 22 Aug 2023 00:44:26 +0900 Subject: [PATCH 08/25] add docs files --- docs/source/clustering/mutual_info_score.rst | 21 ++++++++++++++++++++ docs/source/index.rst | 8 ++++++++ 2 files changed, 29 insertions(+) create mode 100644 docs/source/clustering/mutual_info_score.rst diff --git a/docs/source/clustering/mutual_info_score.rst b/docs/source/clustering/mutual_info_score.rst new file mode 100644 index 00000000000..1b7d13519f4 --- /dev/null +++ b/docs/source/clustering/mutual_info_score.rst @@ -0,0 +1,21 @@ +.. customcarditem:: + :header: Mutual Information Score + :image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/clustering.svg + :tags: Clustering + +.. include:: ../links.rst + +################### +Mutual Info. Score +################### + +Module Interface +________________ + +.. autoclass:: torchmetrics.MutualInfoScore + :exclude-members: update, compute + +Functional Interface +____________________ + +.. autofunction:: torchmetrics.functional.mutual_info_score diff --git a/docs/source/index.rst b/docs/source/index.rst index 9da8cf0a51a..af7b6ff798b 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -158,6 +158,14 @@ Or directly from conda classification/* +.. toctree:: + :maxdepth: 2 + :name: clustering + :caption: Clustering + :glob: + + clustering/* + .. toctree:: :maxdepth: 2 :name: detection From a0308d2800763c47f2acb7a87bfd49a2192af56e Mon Sep 17 00:00:00 2001 From: Jirka Date: Tue, 22 Aug 2023 14:11:36 +0200 Subject: [PATCH 09/25] releasing 1.1.0 --- CHANGELOG.md | 36 +---------------------------------- src/torchmetrics/__about__.py | 2 +- 2 files changed, 2 insertions(+), 36 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 1d4c262ac67..356aca21769 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,56 +7,22 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 **Note: we move fast, but still we preserve 0.1 version (one feature release) back compatibility.** -## [UnReleased] - 2023-MM-DD +## [1.1.0] - 2023-08-22 ### Added - Added source aggregated signal-to-distortion ratio (SA-SDR) metric ([#1882](https://github.com/Lightning-AI/torchmetrics/pull/1882) - - - Added `VisualInformationFidelity` to image package ([#1830](https://github.com/Lightning-AI/torchmetrics/pull/1830)) - - - Added `EditDistance` to text package ([#1906](https://github.com/Lightning-AI/torchmetrics/pull/1906)) - - - Added `top_k` argument to `RetrievalMRR` in retrieval package ([#1961](https://github.com/Lightning-AI/torchmetrics/pull/1961)) - - - Added support for evaluating `"segm"` and `"bbox"` detection in `MeanAveragePrecision` at the same time ([#1928](https://github.com/Lightning-AI/torchmetrics/pull/1928)) - - - Added `PerceptualPathLength` to image package ([#1939](https://github.com/Lightning-AI/torchmetrics/pull/1939)) - - - Added support for multioutput evaluation in `MeanSquaredError` ([#1937](https://github.com/Lightning-AI/torchmetrics/pull/1937)) - - - Added argument `extended_summary` to `MeanAveragePrecision` such that precision, recall, iou can be easily returned ([#1983](https://github.com/Lightning-AI/torchmetrics/pull/1983)) - - - Added warning to `ClipScore` if long captions are detected and truncate ([#2001](https://github.com/Lightning-AI/torchmetrics/pull/2001)) - - - Added `CLIPImageQualityAssessment` to multimodal package ([#1931](https://github.com/Lightning-AI/torchmetrics/pull/1931)) - - - Added new property `metric_state` to all metrics for users to investigate currently stored tensors in memory ([#2006](https://github.com/Lightning-AI/torchmetrics/pull/2006)) -### Changed - -- - - -### Removed - -- - - -### Fixed - -- - ## [1.0.3] - 2023-08-08 diff --git a/src/torchmetrics/__about__.py b/src/torchmetrics/__about__.py index 47ee37bceba..2115111fa44 100644 --- a/src/torchmetrics/__about__.py +++ b/src/torchmetrics/__about__.py @@ -1,4 +1,4 @@ -__version__ = "1.1.0.dev" +__version__ = "1.1.0" __author__ = "Lightning-AI et al." __author_email__ = "name@pytorchlightning.ai" __license__ = "Apache-2.0" From 0d3fec9b80dc33b71b2777f0c2b51468ebc00a20 Mon Sep 17 00:00:00 2001 From: Shion Date: Tue, 22 Aug 2023 23:59:20 +0900 Subject: [PATCH 10/25] Create util functions for clustering. Fix metric implementation. --- docs/source/clustering/mutual_info_score.rst | 6 +- docs/source/links.rst | 1 + .../clustering/mutual_info_score.py | 14 ++-- .../functional/clustering/__init__.py | 13 ++++ .../clustering/mutual_info_score.py | 59 +-------------- .../functional/clustering/utils.py | 74 +++++++++++++++++++ 6 files changed, 103 insertions(+), 64 deletions(-) create mode 100644 src/torchmetrics/functional/clustering/utils.py diff --git a/docs/source/clustering/mutual_info_score.rst b/docs/source/clustering/mutual_info_score.rst index 1b7d13519f4..39291801ae9 100644 --- a/docs/source/clustering/mutual_info_score.rst +++ b/docs/source/clustering/mutual_info_score.rst @@ -5,9 +5,9 @@ .. include:: ../links.rst -################### -Mutual Info. Score -################### +######################## +Mutual Information Score +######################## Module Interface ________________ diff --git a/docs/source/links.rst b/docs/source/links.rst index 4ca837ccd64..7627490c661 100644 --- a/docs/source/links.rst +++ b/docs/source/links.rst @@ -150,4 +150,5 @@ .. _CIOU: https://arxiv.org/abs/2005.03572 .. _DIOU: https://arxiv.org/abs/1911.08287v1 .. _GIOU: https://arxiv.org/abs/1902.09630 +.. _Mutual Information Score: https://en.wikipedia.org/wiki/Mutual_information .. _pycocotools: https://github.com/cocodataset/cocoapi/tree/master/PythonAPI/pycocotools diff --git a/src/torchmetrics/clustering/mutual_info_score.py b/src/torchmetrics/clustering/mutual_info_score.py index 7605d12b08b..ff999d7bb0b 100644 --- a/src/torchmetrics/clustering/mutual_info_score.py +++ b/src/torchmetrics/clustering/mutual_info_score.py @@ -16,8 +16,9 @@ import torch from torch import Tensor -from torchmetrics.functional.clustering.mutual_info_score import _mutual_info_score_compute, _mutual_info_score_update +from torchmetrics.functional.clustering.mutual_info_score import mutual_info_score from torchmetrics.metric import Metric +from torchmetrics.utilities.data import dim_zero_cat from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE @@ -71,17 +72,18 @@ class MutualInfoScore(Metric): def __init__(self, **kwargs: Any) -> None: super().__init__(**kwargs) - # self.num_classes = num_classes - # - # self.add_state("contingency", default=torch.zeros(self.num_classes), dist_reduce_fx=None) + + self.add_state("preds", default=[], dist_reduce_fx="cat") + self.add_state("target", default=[], dist_reduce_fx="cat") def update(self, preds: Tensor, target: Tensor) -> None: """Update state with predictions and targets.""" - self.contingency = _mutual_info_score_update(preds, target) + self.preds.append(preds) + self.target.append(target) def compute(self) -> Tensor: """Compute mutual information over state.""" - return _mutual_info_score_compute(self.contingency) + return mutual_info_score(dim_zero_cat(self.preds), dim_zero_cat(self.target)) def plot(self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_AX_TYPE] = None) -> _PLOT_OUT_TYPE: """Plot a single or multiple values from the metric. diff --git a/src/torchmetrics/functional/clustering/__init__.py b/src/torchmetrics/functional/clustering/__init__.py index 576acda5108..c6f46126ca3 100644 --- a/src/torchmetrics/functional/clustering/__init__.py +++ b/src/torchmetrics/functional/clustering/__init__.py @@ -1,3 +1,16 @@ +# Copyright The Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from torchmetrics.functional.clustering.mutual_info_score import mutual_info_score __all__ = ["mutual_info_score"] diff --git a/src/torchmetrics/functional/clustering/mutual_info_score.py b/src/torchmetrics/functional/clustering/mutual_info_score.py index 180aa4496a3..6a61e2ac40d 100644 --- a/src/torchmetrics/functional/clustering/mutual_info_score.py +++ b/src/torchmetrics/functional/clustering/mutual_info_score.py @@ -11,62 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Tuple +from typing import Tuple import torch from torch import Tensor, tensor - -from torchmetrics.utilities.checks import _check_same_shape - - -def _check_cluster_labels(preds: Tensor, target: Tensor) -> None: - """Check shape of input tensors.""" - _check_same_shape(preds, target) - if ( - torch.is_floating_point(preds) - or torch.is_complex(preds) - or torch.is_floating_point(target) - or torch.is_complex(target) - ): - raise ValueError( - f"Expected real, discrete values but received {preds.dtype} for" - f"predictions and {target.dtype} for target labels instead." - ) - - -def _calculate_contingency_matrix( - preds: Tensor, target: Tensor, eps: Optional[float] = 1e-16, sparse: bool = False -) -> Tensor: - """Calculate contingency matrix. - - Args: - preds: predicted labels - target: ground truth labels - sparse: If True, returns contingency matrix as a sparse matrix. - - Returns: - contingency: contingency matrix of shape (n_classes_target, n_classes_preds) - - """ - if eps is not None and sparse is True: - raise ValueError("Cannot specify `eps` and return sparse tensor.") - - preds_classes, preds_idx = torch.unique(preds, return_inverse=True) - target_classes, target_idx = torch.unique(target, return_inverse=True) - - n_classes_preds = preds_classes.size(0) - n_classes_target = target_classes.size(0) - - contingency = torch.sparse_coo_tensor( - torch.stack((target_idx, preds_idx)), torch.ones(target_idx.size(0)), (n_classes_target, n_classes_preds) - ) - - if not sparse: - contingency = contingency.to_dense() - if eps: - contingency = contingency + eps - - return contingency +from torchmetrics.functional.clustering.utils import calculate_contingency_matrix, check_cluster_labels def _mutual_info_score_update( @@ -84,8 +33,8 @@ def _mutual_info_score_update( contingency: contingency matrix """ - _check_cluster_labels(preds, target) - return _calculate_contingency_matrix(preds, target) + check_cluster_labels(preds, target) + return calculate_contingency_matrix(preds, target) def _mutual_info_score_compute(contingency: Tensor) -> Tensor: diff --git a/src/torchmetrics/functional/clustering/utils.py b/src/torchmetrics/functional/clustering/utils.py new file mode 100644 index 00000000000..334584b8ff4 --- /dev/null +++ b/src/torchmetrics/functional/clustering/utils.py @@ -0,0 +1,74 @@ +# Copyright The Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch + +from torch import Tensor +from torchmetrics.utilities.checks import _check_same_shape +from typing import Optional + + +def calculate_contingency_matrix( + preds: Tensor, target: Tensor, eps: Optional[float] = 1e-16, sparse: bool = False +) -> Tensor: + """Calculate contingency matrix. + + Args: + preds: predicted labels + target: ground truth labels + sparse: If True, returns contingency matrix as a sparse matrix. + + Returns: + contingency: contingency matrix of shape (n_classes_target, n_classes_preds) + + """ + if eps is not None and sparse is True: + raise ValueError("Cannot specify `eps` and return sparse tensor.") + + preds_classes, preds_idx = torch.unique(preds, return_inverse=True) + target_classes, target_idx = torch.unique(target, return_inverse=True) + + n_classes_preds = preds_classes.size(0) + n_classes_target = target_classes.size(0) + + contingency = torch.sparse_coo_tensor( + torch.stack((target_idx, preds_idx)), torch.ones(target_idx.size(0)), (n_classes_target, n_classes_preds) + ) + + if not sparse: + contingency = contingency.to_dense() + if eps: + contingency = contingency + eps + + return contingency + + +def check_cluster_labels(preds: Tensor, target: Tensor) -> None: + """Check shape of input tensors and if they are real, discrete tensors. + + Args: + preds: predicted labels + target: ground truth labels + + """ + _check_same_shape(preds, target) + if ( + torch.is_floating_point(preds) + or torch.is_complex(preds) + or torch.is_floating_point(target) + or torch.is_complex(target) + ): + raise ValueError( + f"Expected real, discrete values but received {preds.dtype} for" + f"predictions and {target.dtype} for target labels instead." + ) From 7dad1f97f1e2927e923b38c3b277f3fa282b7f20 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 22 Aug 2023 15:00:37 +0000 Subject: [PATCH 11/25] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/torchmetrics/functional/clustering/mutual_info_score.py | 1 + src/torchmetrics/functional/clustering/utils.py | 5 +++-- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/src/torchmetrics/functional/clustering/mutual_info_score.py b/src/torchmetrics/functional/clustering/mutual_info_score.py index 6a61e2ac40d..ac2ed03c41d 100644 --- a/src/torchmetrics/functional/clustering/mutual_info_score.py +++ b/src/torchmetrics/functional/clustering/mutual_info_score.py @@ -15,6 +15,7 @@ import torch from torch import Tensor, tensor + from torchmetrics.functional.clustering.utils import calculate_contingency_matrix, check_cluster_labels diff --git a/src/torchmetrics/functional/clustering/utils.py b/src/torchmetrics/functional/clustering/utils.py index 334584b8ff4..a5460d22e58 100644 --- a/src/torchmetrics/functional/clustering/utils.py +++ b/src/torchmetrics/functional/clustering/utils.py @@ -11,11 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import torch +from typing import Optional +import torch from torch import Tensor + from torchmetrics.utilities.checks import _check_same_shape -from typing import Optional def calculate_contingency_matrix( From c36d8a0e7c419fdf1d465117ced9c30c9a3ec95d Mon Sep 17 00:00:00 2001 From: Shion Date: Wed, 23 Aug 2023 00:15:20 +0900 Subject: [PATCH 12/25] Fix ruff-related errors --- src/torchmetrics/clustering/mutual_info_score.py | 4 ++-- .../functional/clustering/mutual_info_score.py | 14 +++++++------- src/torchmetrics/functional/clustering/utils.py | 9 ++++++--- .../unittests/clustering/test_mutual_info_score.py | 1 - 4 files changed, 15 insertions(+), 13 deletions(-) diff --git a/src/torchmetrics/clustering/mutual_info_score.py b/src/torchmetrics/clustering/mutual_info_score.py index ff999d7bb0b..4e2a573870f 100644 --- a/src/torchmetrics/clustering/mutual_info_score.py +++ b/src/torchmetrics/clustering/mutual_info_score.py @@ -13,7 +13,6 @@ # limitations under the License. from typing import Any, List, Optional, Sequence, Union -import torch from torch import Tensor from torchmetrics.functional.clustering.mutual_info_score import mutual_info_score @@ -30,7 +29,8 @@ class MutualInfoScore(Metric): r"""Compute `Mutual Information Score`_. .. math:: - MI(U,V) = \sum_{i=1}^{\abs{U}} \sum_{j=1}^{\abs{V}} \frac{\abs{U_i\cap V_j}}{N} \log\frac{N\abs{U_i\cap V_j}}{\abs{U_i}\abs{V_j}} + MI(U,V) = \sum_{i=1}^{\abs{U}} \sum_{j=1}^{\abs{V}} \frac{\abs{U_i\cap V_j}}{N} + \log\frac{N\abs{U_i\cap V_j}}{\abs{U_i}\abs{V_j}} Where :math:`U` is a tensor of target values, :math:`V` is a tensor of predictions, :math:`\abs{U_i}` is the number of samples in cluster :math:`U_i`, and diff --git a/src/torchmetrics/functional/clustering/mutual_info_score.py b/src/torchmetrics/functional/clustering/mutual_info_score.py index 6a61e2ac40d..05ab1af5605 100644 --- a/src/torchmetrics/functional/clustering/mutual_info_score.py +++ b/src/torchmetrics/functional/clustering/mutual_info_score.py @@ -15,13 +15,13 @@ import torch from torch import Tensor, tensor + from torchmetrics.functional.clustering.utils import calculate_contingency_matrix, check_cluster_labels def _mutual_info_score_update( preds: Tensor, target: Tensor, - # num_classes: int ) -> Tuple[Tensor, Tensor, Tensor]: """Update and return variables required to compute the mutual information score. @@ -47,16 +47,16 @@ def _mutual_info_score_compute(contingency: Tensor) -> Tensor: mutual_info: mutual information score """ - N = contingency.sum() - U = contingency.sum(dim=1) - V = contingency.sum(dim=0) + n = contingency.sum() + u = contingency.sum(dim=1) + v = contingency.sum(dim=0) # Check if preds or target labels only have one cluster - if U.size() == 1 or V.size() == 1: + if u.size() == 1 or v.size() == 1: return tensor(0.0) - log_outer = torch.log(U).reshape(-1, 1) + torch.log(V) - mutual_info = contingency / N * (torch.log(N) + torch.log(contingency) - log_outer) + log_outer = torch.log(u).reshape(-1, 1) + torch.log(v) + mutual_info = contingency / n * (torch.log(n) + torch.log(contingency) - log_outer) return mutual_info.sum() diff --git a/src/torchmetrics/functional/clustering/utils.py b/src/torchmetrics/functional/clustering/utils.py index 334584b8ff4..a0b35277a81 100644 --- a/src/torchmetrics/functional/clustering/utils.py +++ b/src/torchmetrics/functional/clustering/utils.py @@ -11,11 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import torch +from typing import Optional +import torch from torch import Tensor + from torchmetrics.utilities.checks import _check_same_shape -from typing import Optional def calculate_contingency_matrix( @@ -26,7 +27,9 @@ def calculate_contingency_matrix( Args: preds: predicted labels target: ground truth labels - sparse: If True, returns contingency matrix as a sparse matrix. + eps: value added to contingency matrix + sparse: If True, returns contingency matrix as a sparse matrix. Else, return as dense matrix. + `eps` must be `None` if `sparse` is `True`. Returns: contingency: contingency matrix of shape (n_classes_target, n_classes_preds) diff --git a/tests/unittests/clustering/test_mutual_info_score.py b/tests/unittests/clustering/test_mutual_info_score.py index d594d0f140a..555539e4a91 100644 --- a/tests/unittests/clustering/test_mutual_info_score.py +++ b/tests/unittests/clustering/test_mutual_info_score.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. from collections import namedtuple -from functools import partial import pytest import torch From f677483e9b0d9d9c20850ef9e77d5ba8b9dbd1e4 Mon Sep 17 00:00:00 2001 From: Shion Date: Wed, 23 Aug 2023 00:34:24 +0900 Subject: [PATCH 13/25] Fix docstring examples --- .../clustering/mutual_info_score.py | 20 ++++++++++--------- .../clustering/mutual_info_score.py | 2 +- 2 files changed, 12 insertions(+), 10 deletions(-) diff --git a/src/torchmetrics/clustering/mutual_info_score.py b/src/torchmetrics/clustering/mutual_info_score.py index 4e2a573870f..51709716c58 100644 --- a/src/torchmetrics/clustering/mutual_info_score.py +++ b/src/torchmetrics/clustering/mutual_info_score.py @@ -52,12 +52,13 @@ class MutualInfoScore(Metric): - ``mi_score`` (:class:`~torch.Tensor`): A tensor with the Mutual Information Score Example: + >>> import torch >>> from torchmetrics.clustering import MutualInfoScore - >>> target = torch.tensor([]) - >>> preds = torch.tensor([]) + >>> target = torch.tensor([0, 2, 1, 1, 0]) + >>> preds = torch.tensor([2, 1, 0, 1, 0]) >>> mi_score = MutualInfoScore() >>> mi_score(preds, target) - tensor() + tensor(0.5004) """ @@ -106,8 +107,9 @@ def plot(self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_ >>> # Example plotting a single value >>> import torch >>> from torchmetrics.clustering import MutualInfoScore - >>> metric = MutualInfoScore(num_classes=5) - >>> metric.update(torch.randint(0, 4, (100,)), torch.randint(0, 4, (100,))) + >>> metric = MutualInfoScore() + >>> metric.update(torch.randint(0, 4, (10,)), torch.randint(0, 4, (10,))) + >>> metric.compute() >>> fig_, ax_ = metric.plot() .. plot:: @@ -116,11 +118,11 @@ def plot(self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_ >>> # Example plotting multiple values >>> import torch >>> from torchmetrics.clustering import MutualInfoScore - >>> metric = MutualInfoScore(num_classes=5) - >>> values = [ ] + >>> metric = MutualInfoScore() >>> for _ in range(10): - ... values.append(metric(torch.randint(0, 4, (100,)), torch.randint(0, 4, (100,)))) - >>> fig_, ax_ = metric.plot(values) + ... metric.update(torch.randint(0, 4, (10,)), torch.randint(0, 4, (10,)))) + >>> metric.compute() + >>> fig_, ax_ = metric.plot() """ return self._plot(val, ax) diff --git a/src/torchmetrics/functional/clustering/mutual_info_score.py b/src/torchmetrics/functional/clustering/mutual_info_score.py index 05ab1af5605..7cf032b7ebb 100644 --- a/src/torchmetrics/functional/clustering/mutual_info_score.py +++ b/src/torchmetrics/functional/clustering/mutual_info_score.py @@ -72,7 +72,7 @@ def mutual_info_score(preds: Tensor, target: Tensor) -> Tensor: >>> target = torch.tensor([0, 3, 2, 2, 1]) >>> preds = torch.tensor([1, 3, 2, 0, 1]) >>> mutual_info_score(preds, target) - tensor([1.05492]) + tensor(1.0549) """ contingency = _mutual_info_score_update(preds, target) From 0d361d16f5c2ccfc5d3228a69ffd527e02b73d1e Mon Sep 17 00:00:00 2001 From: Shion Date: Wed, 23 Aug 2023 00:56:32 +0900 Subject: [PATCH 14/25] Test functional metric for symmetry --- .../unittests/clustering/test_mutual_info_score.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/tests/unittests/clustering/test_mutual_info_score.py b/tests/unittests/clustering/test_mutual_info_score.py index 555539e4a91..ca9fa1fa0b5 100644 --- a/tests/unittests/clustering/test_mutual_info_score.py +++ b/tests/unittests/clustering/test_mutual_info_score.py @@ -60,14 +60,12 @@ class TestMutualInfoScore(MetricTester): @pytest.mark.parametrize("ddp", [True, False]) def test_mutual_info_score(self, preds, target, compute_on_cpu, ddp): """Test class implementation of metric.""" - # metric_args = {"num_classes": NUM_CLASSES} self.run_class_metric_test( ddp=ddp, preds=preds, target=target, metric_class=MutualInfoScore, reference_metric=sklearn_mutual_info_score, - # metric_args=metric_args ) def test_mutual_info_score_functional(self, preds, target): @@ -85,3 +83,15 @@ def test_mutual_info_score_functional_raises_invalid_task(): preds, target = _float_inputs with pytest.raises(ValueError, match=r"Expected *"): mutual_info_score(preds, target) + + +@pytest.mark.parametrize( + ("preds", "target"), + [ + (_single_target_inputs1.preds, _single_target_inputs1.target), + ], +) +def test_mutual_info_score_functional_is_symmetric(preds, target): + """Check that the metric funtional is symmetric.""" + for p, t in zip(preds, target): + assert torch.allclose(mutual_info_score(p, t), mutual_info_score(t, p)) From 422ace325f5364c736f4cbec62c24fbdf34ebdd8 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Detlefsen Date: Wed, 23 Aug 2023 08:16:35 +0200 Subject: [PATCH 15/25] changelog --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 94a52a9a6fa..baef1be3009 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,7 +11,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added -- +- Added `MutualInformationScore` metric to cluster package ([#2008](https://github.com/Lightning-AI/torchmetrics/pull/2008) ### Changed From bf05b8bc5aa730d1ba1743c255ea5fd0b8e7390d Mon Sep 17 00:00:00 2001 From: Shion Date: Wed, 23 Aug 2023 22:47:23 +0900 Subject: [PATCH 16/25] Fix type hint error. Additional checks for tensor shapes. --- .../functional/clustering/mutual_info_score.py | 7 +------ src/torchmetrics/functional/clustering/utils.py | 2 ++ 2 files changed, 3 insertions(+), 6 deletions(-) diff --git a/src/torchmetrics/functional/clustering/mutual_info_score.py b/src/torchmetrics/functional/clustering/mutual_info_score.py index 7cf032b7ebb..f81ab46fd96 100644 --- a/src/torchmetrics/functional/clustering/mutual_info_score.py +++ b/src/torchmetrics/functional/clustering/mutual_info_score.py @@ -11,18 +11,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Tuple - import torch from torch import Tensor, tensor from torchmetrics.functional.clustering.utils import calculate_contingency_matrix, check_cluster_labels -def _mutual_info_score_update( - preds: Tensor, - target: Tensor, -) -> Tuple[Tensor, Tensor, Tensor]: +def _mutual_info_score_update(preds: Tensor, target: Tensor) -> Tensor: """Update and return variables required to compute the mutual information score. Args: diff --git a/src/torchmetrics/functional/clustering/utils.py b/src/torchmetrics/functional/clustering/utils.py index a0b35277a81..ae11a9c4524 100644 --- a/src/torchmetrics/functional/clustering/utils.py +++ b/src/torchmetrics/functional/clustering/utils.py @@ -65,6 +65,8 @@ def check_cluster_labels(preds: Tensor, target: Tensor) -> None: """ _check_same_shape(preds, target) + if preds.ndim != 1: + raise ValueError(f"Expected arguments to be 1d tensors but got {preds.ndim} and {target.ndim}") if ( torch.is_floating_point(preds) or torch.is_complex(preds) From e9a123319f6089e60ce28bd8a1b07799de9ee106 Mon Sep 17 00:00:00 2001 From: Shion Matsumoto Date: Wed, 23 Aug 2023 22:48:52 +0900 Subject: [PATCH 17/25] Update src/torchmetrics/clustering/mutual_info_score.py Co-authored-by: Nicki Skafte Detlefsen --- src/torchmetrics/clustering/mutual_info_score.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/torchmetrics/clustering/mutual_info_score.py b/src/torchmetrics/clustering/mutual_info_score.py index 51709716c58..fd56000ba6a 100644 --- a/src/torchmetrics/clustering/mutual_info_score.py +++ b/src/torchmetrics/clustering/mutual_info_score.py @@ -39,9 +39,6 @@ class MutualInfoScore(Metric): The metric is symmetric, therefore swapping :math:`U` and :math:`V` yields the same mutual information score. - Args: - kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. - As input to ``forward`` and ``update`` the metric accepts the following input: - ``preds`` (:class:`~torch.Tensor`): either single output float tensor with shape ``(N,)`` @@ -51,6 +48,9 @@ class MutualInfoScore(Metric): - ``mi_score`` (:class:`~torch.Tensor`): A tensor with the Mutual Information Score + Args: + kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. + Example: >>> import torch >>> from torchmetrics.clustering import MutualInfoScore From 9cff8767c1109c27756ffdb60a66beff8964a584 Mon Sep 17 00:00:00 2001 From: Shion Matsumoto Date: Wed, 23 Aug 2023 22:49:04 +0900 Subject: [PATCH 18/25] Update src/torchmetrics/clustering/mutual_info_score.py Co-authored-by: Nicki Skafte Detlefsen --- src/torchmetrics/clustering/mutual_info_score.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/torchmetrics/clustering/mutual_info_score.py b/src/torchmetrics/clustering/mutual_info_score.py index fd56000ba6a..123052bd4ac 100644 --- a/src/torchmetrics/clustering/mutual_info_score.py +++ b/src/torchmetrics/clustering/mutual_info_score.py @@ -66,7 +66,6 @@ class MutualInfoScore(Metric): higher_is_better = None full_state_update: bool = True plot_lower_bound: float = 0.0 - plot_upper_bound: float = 1.0 # theoretical upper bound is +inf preds: List[Tensor] target: List[Tensor] contingency: Tensor From e4523d41d04783e4c12b42898ee79952464fc486 Mon Sep 17 00:00:00 2001 From: Shion Date: Thu, 24 Aug 2023 20:55:35 +0900 Subject: [PATCH 19/25] Test contingency matrix calculation --- .../functional/clustering/utils.py | 30 ++++++- tests/unittests/clustering/test_utils.py | 78 +++++++++++++++++++ 2 files changed, 104 insertions(+), 4 deletions(-) create mode 100644 tests/unittests/clustering/test_utils.py diff --git a/src/torchmetrics/functional/clustering/utils.py b/src/torchmetrics/functional/clustering/utils.py index ae11a9c4524..7f685e8cbcc 100644 --- a/src/torchmetrics/functional/clustering/utils.py +++ b/src/torchmetrics/functional/clustering/utils.py @@ -20,7 +20,7 @@ def calculate_contingency_matrix( - preds: Tensor, target: Tensor, eps: Optional[float] = 1e-16, sparse: bool = False + preds: Tensor, target: Tensor, eps: Optional[float] = None, sparse: bool = False ) -> Tensor: """Calculate contingency matrix. @@ -34,9 +34,21 @@ def calculate_contingency_matrix( Returns: contingency: contingency matrix of shape (n_classes_target, n_classes_preds) + Example: + >>> import torch + >>> from torchmetrics.functional.clustering.utils import calculate_contingency_matrix + >>> preds = torch.tensor([2, 1, 0, 1, 0]) + >>> target = torch.tensor([0, 2, 1, 1, 0]) + >>> calculate_contingency_matrix(preds, target, eps=1e-16) + tensor([[1.0000e+00, 1.0000e-16, 1.0000e+00], + [1.0000e+00, 1.0000e+00, 1.0000e-16], + [1.0000e-16, 1.0000e+00, 1.0000e-16]]) + """ if eps is not None and sparse is True: raise ValueError("Cannot specify `eps` and return sparse tensor.") + if preds.ndim != 1 or target.ndim != 1: + raise ValueError(f"Expected 1d `preds` and `target` but got {preds.ndim} and {target.dim}.") preds_classes, preds_idx = torch.unique(preds, return_inverse=True) target_classes, target_idx = torch.unique(target, return_inverse=True) @@ -45,13 +57,23 @@ def calculate_contingency_matrix( n_classes_target = target_classes.size(0) contingency = torch.sparse_coo_tensor( - torch.stack((target_idx, preds_idx)), torch.ones(target_idx.size(0)), (n_classes_target, n_classes_preds) + torch.stack( + ( + target_idx, + preds_idx, + ) + ), + torch.ones(target_idx.size(0)), + ( + n_classes_target, + n_classes_preds, + ), ) if not sparse: contingency = contingency.to_dense() - if eps: - contingency = contingency + eps + if eps: + contingency = contingency + eps return contingency diff --git a/tests/unittests/clustering/test_utils.py b/tests/unittests/clustering/test_utils.py new file mode 100644 index 00000000000..95ee1a6a4a7 --- /dev/null +++ b/tests/unittests/clustering/test_utils.py @@ -0,0 +1,78 @@ +# Copyright The Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from collections import namedtuple + +import numpy as np +import pytest +import torch +from sklearn.metrics.cluster import contingency_matrix as sklearn_contingency_matrix +from torchmetrics.functional.clustering.utils import calculate_contingency_matrix + +from unittests import BATCH_SIZE +from unittests.helpers import seed_all + +seed_all(42) + +Input = namedtuple("Input", ["preds", "target"]) +NUM_CLASSES = 10 + +_sklearn_inputs = Input( + preds=torch.tensor([1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3]), + target=torch.tensor([1, 1, 1, 1, 2, 1, 2, 2, 2, 2, 3, 1, 3, 3, 3, 2, 2]), +) + +_single_dim_inputs = Input( + preds=torch.randint(high=NUM_CLASSES, size=(BATCH_SIZE,)), + target=torch.randint(high=NUM_CLASSES, size=(BATCH_SIZE,)), +) + +_multi_dim_inputs = Input( + preds=torch.randint(high=NUM_CLASSES, size=(BATCH_SIZE, 2)), + target=torch.randint(high=NUM_CLASSES, size=(BATCH_SIZE, 2)), +) + + +@pytest.mark.parametrize( + ("preds", "target"), + [(_sklearn_inputs.preds, _sklearn_inputs.target), (_single_dim_inputs.preds, _single_dim_inputs.target)], +) +class TestContingencyMatrix: + """Test calculation of dense and sparse contingency matrices.""" + + atol = 1e-8 + + @pytest.mark.parametrize("eps", [None, 1e-16]) + def test_contingency_matrix_dense(self, preds, target, eps): + """Check that dense contingency matrices are calculated correctly.""" + tm_c = calculate_contingency_matrix(preds, target, eps) + sklearn_c = sklearn_contingency_matrix(target, preds, eps=eps) + assert np.allclose(tm_c, sklearn_c, atol=self.atol) + + def test_contingency_matrix_sparse(self, preds, target): + """Check that sparse contingency matrices are calculated correctly.""" + tm_c = calculate_contingency_matrix(preds, target, sparse=True).to_dense().numpy() + sklearn_c = sklearn_contingency_matrix(target, preds, sparse=True).toarray() + assert np.allclose(tm_c, sklearn_c, atol=self.atol) + + +def test_eps_and_sparse_error(): + """Check that contingency matrix is not calculated if `eps` is nonzero and `sparse` is True.""" + with pytest.raises(ValueError, match="Cannot specify*"): + calculate_contingency_matrix(_single_dim_inputs.preds, _single_dim_inputs.target, eps=1e-16, sparse=True) + + +def test_multidimensional_contingency_error(): + """Check that contingency matrix is not calculated for multidimensional input.""" + with pytest.raises(ValueError, match="Expected 1d*"): + calculate_contingency_matrix(_multi_dim_inputs.preds, _multi_dim_inputs.target) From f1cc3df37e0b943177cbeb1212693776931e6d68 Mon Sep 17 00:00:00 2001 From: Shion Date: Thu, 24 Aug 2023 21:57:07 +0900 Subject: [PATCH 20/25] fix mutual info score calculation. all test passing. --- .../functional/clustering/mutual_info_score.py | 7 ++++++- tests/unittests/clustering/test_mutual_info_score.py | 8 ++++++++ 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/src/torchmetrics/functional/clustering/mutual_info_score.py b/src/torchmetrics/functional/clustering/mutual_info_score.py index f81ab46fd96..f7c7cbfa587 100644 --- a/src/torchmetrics/functional/clustering/mutual_info_score.py +++ b/src/torchmetrics/functional/clustering/mutual_info_score.py @@ -50,7 +50,12 @@ def _mutual_info_score_compute(contingency: Tensor) -> Tensor: if u.size() == 1 or v.size() == 1: return tensor(0.0) - log_outer = torch.log(u).reshape(-1, 1) + torch.log(v) + # Find indices of nonzero values in U and V + nzu, nzv = torch.nonzero(contingency, as_tuple=True) + contingency = contingency[nzu, nzv] + + # Calculate MI using entries corresponding to nonzero contingency matrix entries + log_outer = torch.log(u[nzu]) + torch.log(v[nzv]) mutual_info = contingency / n * (torch.log(n) + torch.log(contingency) - log_outer) return mutual_info.sum() diff --git a/tests/unittests/clustering/test_mutual_info_score.py b/tests/unittests/clustering/test_mutual_info_score.py index ca9fa1fa0b5..5c89b58b169 100644 --- a/tests/unittests/clustering/test_mutual_info_score.py +++ b/tests/unittests/clustering/test_mutual_info_score.py @@ -78,6 +78,14 @@ def test_mutual_info_score_functional(self, preds, target): ) +def test_mutual_info_score_functional_single_cluster(): + """Check that metric rejects continuous-valued inputs.""" + tensor_a = torch.randint(NUM_CLASSES, (BATCH_SIZE,)) + tensor_b = torch.zeros(BATCH_SIZE, dtype=torch.int) + assert torch.allclose(mutual_info_score(tensor_a, tensor_b), torch.tensor(0.0)) + assert torch.allclose(mutual_info_score(tensor_b, tensor_a), torch.tensor(0.0)) + + def test_mutual_info_score_functional_raises_invalid_task(): """Check that metric rejects continuous-valued inputs.""" preds, target = _float_inputs From f278c5c25bc5713f62bbe282eeb723a91762f3a6 Mon Sep 17 00:00:00 2001 From: Shion Date: Thu, 24 Aug 2023 22:00:15 +0900 Subject: [PATCH 21/25] fix plotting docstring --- src/torchmetrics/clustering/mutual_info_score.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/src/torchmetrics/clustering/mutual_info_score.py b/src/torchmetrics/clustering/mutual_info_score.py index 123052bd4ac..f4388f0a528 100644 --- a/src/torchmetrics/clustering/mutual_info_score.py +++ b/src/torchmetrics/clustering/mutual_info_score.py @@ -54,8 +54,8 @@ class MutualInfoScore(Metric): Example: >>> import torch >>> from torchmetrics.clustering import MutualInfoScore - >>> target = torch.tensor([0, 2, 1, 1, 0]) >>> preds = torch.tensor([2, 1, 0, 1, 0]) + >>> target = torch.tensor([0, 2, 1, 1, 0]) >>> mi_score = MutualInfoScore() >>> mi_score(preds, target) tensor(0.5004) @@ -108,8 +108,7 @@ def plot(self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_ >>> from torchmetrics.clustering import MutualInfoScore >>> metric = MutualInfoScore() >>> metric.update(torch.randint(0, 4, (10,)), torch.randint(0, 4, (10,))) - >>> metric.compute() - >>> fig_, ax_ = metric.plot() + >>> fig_, ax_ = metric.plot(metric.compute()) .. plot:: :scale: 75 @@ -120,8 +119,7 @@ def plot(self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_ >>> metric = MutualInfoScore() >>> for _ in range(10): ... metric.update(torch.randint(0, 4, (10,)), torch.randint(0, 4, (10,)))) - >>> metric.compute() - >>> fig_, ax_ = metric.plot() + >>> fig_, ax_ = metric.plot(metric.compute()) """ return self._plot(val, ax) From c866355203111b55b3c6f13cdb4215e8a3d4e8ca Mon Sep 17 00:00:00 2001 From: Shion Date: Thu, 24 Aug 2023 22:20:44 +0900 Subject: [PATCH 22/25] add paren --- src/torchmetrics/clustering/mutual_info_score.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/torchmetrics/clustering/mutual_info_score.py b/src/torchmetrics/clustering/mutual_info_score.py index f4388f0a528..86118daf41c 100644 --- a/src/torchmetrics/clustering/mutual_info_score.py +++ b/src/torchmetrics/clustering/mutual_info_score.py @@ -118,7 +118,7 @@ def plot(self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_ >>> from torchmetrics.clustering import MutualInfoScore >>> metric = MutualInfoScore() >>> for _ in range(10): - ... metric.update(torch.randint(0, 4, (10,)), torch.randint(0, 4, (10,)))) + ... metric.update(torch.randint(0, 4, (10,)), torch.randint(0, 4, (10,))) >>> fig_, ax_ = metric.plot(metric.compute()) """ From ca5ff5fc39d488277d650fe6cf4da5566a9bb9b0 Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Fri, 25 Aug 2023 15:05:47 +0200 Subject: [PATCH 23/25] fix doc import --- docs/source/clustering/mutual_info_score.rst | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/clustering/mutual_info_score.rst b/docs/source/clustering/mutual_info_score.rst index 39291801ae9..e5adf06eaa9 100644 --- a/docs/source/clustering/mutual_info_score.rst +++ b/docs/source/clustering/mutual_info_score.rst @@ -12,10 +12,10 @@ Mutual Information Score Module Interface ________________ -.. autoclass:: torchmetrics.MutualInfoScore +.. autoclass:: torchmetrics.clustering.MutualInfoScore :exclude-members: update, compute Functional Interface ____________________ -.. autofunction:: torchmetrics.functional.mutual_info_score +.. autofunction:: torchmetrics.functional.clustering.mutual_info_score From 157e8f833bd02c50fbcdba0ca1d2b4ddb4d9e7ba Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Fri, 25 Aug 2023 15:10:54 +0200 Subject: [PATCH 24/25] fix on gpu --- src/torchmetrics/functional/clustering/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/torchmetrics/functional/clustering/utils.py b/src/torchmetrics/functional/clustering/utils.py index 7f685e8cbcc..64dff0377ee 100644 --- a/src/torchmetrics/functional/clustering/utils.py +++ b/src/torchmetrics/functional/clustering/utils.py @@ -63,7 +63,7 @@ def calculate_contingency_matrix( preds_idx, ) ), - torch.ones(target_idx.size(0)), + torch.ones(target_idx.shape[0], dtype=preds_idx.dtype, device=preds_idx.device), ( n_classes_target, n_classes_preds, From 51d3f2a42f18b60d45884268ba4c4d6931ca58de Mon Sep 17 00:00:00 2001 From: Nicki Skafte Detlefsen Date: Fri, 25 Aug 2023 16:37:40 +0200 Subject: [PATCH 25/25] remove unused arg --- tests/unittests/clustering/test_mutual_info_score.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/tests/unittests/clustering/test_mutual_info_score.py b/tests/unittests/clustering/test_mutual_info_score.py index 5c89b58b169..c4e0e56f38e 100644 --- a/tests/unittests/clustering/test_mutual_info_score.py +++ b/tests/unittests/clustering/test_mutual_info_score.py @@ -56,9 +56,8 @@ class TestMutualInfoScore(MetricTester): atol = 1e-5 - @pytest.mark.parametrize("compute_on_cpu", [True, False]) @pytest.mark.parametrize("ddp", [True, False]) - def test_mutual_info_score(self, preds, target, compute_on_cpu, ddp): + def test_mutual_info_score(self, preds, target, ddp): """Test class implementation of metric.""" self.run_class_metric_test( ddp=ddp, @@ -79,7 +78,7 @@ def test_mutual_info_score_functional(self, preds, target): def test_mutual_info_score_functional_single_cluster(): - """Check that metric rejects continuous-valued inputs.""" + """Check that for single cluster the metric returns 0.""" tensor_a = torch.randint(NUM_CLASSES, (BATCH_SIZE,)) tensor_b = torch.zeros(BATCH_SIZE, dtype=torch.int) assert torch.allclose(mutual_info_score(tensor_a, tensor_b), torch.tensor(0.0))