diff --git a/CHANGELOG.md b/CHANGELOG.md index 8de1ee274b1..9be496340e2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -26,6 +26,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added `LogCoshError` to regression package ([#1316](https://github.com/Lightning-AI/metrics/pull/1316)) +- Added `CramersV` to the new nominal package ([#1298](https://github.com/Lightning-AI/metrics/pull/1298)) + + ### Changed - Changed `MeanAveragePrecision` to vectorize `_find_best_gt_match` operation ([#1259](https://github.com/Lightning-AI/metrics/pull/1259)) diff --git a/docs/source/index.rst b/docs/source/index.rst index 1977f03e303..c54b21d3d39 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -134,6 +134,14 @@ Or directly from conda pages/lightning pages/retrieval +.. toctree:: + :maxdepth: 2 + :name: aggregation + :caption: Aggregation + :glob: + + aggregation/* + .. toctree:: :maxdepth: 2 :name: audio @@ -150,6 +158,14 @@ Or directly from conda classification/* +.. toctree:: + :maxdepth: 2 + :name: detection + :caption: Detection + :glob: + + detection/* + .. toctree:: :maxdepth: 2 :name: image @@ -160,11 +176,11 @@ Or directly from conda .. toctree:: :maxdepth: 2 - :name: detection - :caption: Detection + :name: nominal + :caption: Nominal :glob: - detection/* + nominal/* .. toctree:: :maxdepth: 2 @@ -198,14 +214,6 @@ Or directly from conda text/* -.. toctree:: - :maxdepth: 2 - :name: aggregation - :caption: Aggregation - :glob: - - aggregation/* - .. toctree:: :maxdepth: 2 :name: wrappers diff --git a/docs/source/links.rst b/docs/source/links.rst index f53f4c6ecbb..687577ddcfa 100644 --- a/docs/source/links.rst +++ b/docs/source/links.rst @@ -93,6 +93,7 @@ .. _AB divergence: https://pdfs.semanticscholar.org/744b/1166de34cb099100f151f3b1459f141ae25b.pdf .. _Rényi divergence: https://static.renyi.hu/renyi_cikkek/1961_on_measures_of_entropy_and_information.pdf .. 
_Fisher-Rao distance: http://www.scholarpedia.org/article/Fisher-Rao_metric +.. _Cramer's V: https://en.wikipedia.org/wiki/Cram%C3%A9r%27s_V .. _Kendall Rank Correlation Coefficient: https://en.wikipedia.org/wiki/Kendall_rank_correlation_coefficient .. _The Treatment of Ties in Ranking Problems: https://www.jstor.org/stable/2332303 .. _LogCosh Error: https://arxiv.org/pdf/2101.10427.pdf diff --git a/docs/source/nominal/cramers_v.rst b/docs/source/nominal/cramers_v.rst new file mode 100644 index 00000000000..7e6533a730a --- /dev/null +++ b/docs/source/nominal/cramers_v.rst @@ -0,0 +1,26 @@ +.. customcarditem:: + :header: Cramer's V + :image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/tabular_classification.svg + :tags: Nominal + +########## +Cramer's V +########## + +Module Interface +________________ + +.. autoclass:: torchmetrics.CramersV + :noindex: + +Functional Interface +____________________ + +.. autofunction:: torchmetrics.functional.cramers_v + :noindex: + +cramers_v_matrix +^^^^^^^^^^^^^^^^ + +.. 
autofunction:: torchmetrics.functional.nominal.cramers_v_matrix + :noindex: diff --git a/requirements/devel.txt b/requirements/devel.txt index 05aa5e61b9a..0c256bc83c8 100644 --- a/requirements/devel.txt +++ b/requirements/devel.txt @@ -16,3 +16,4 @@ -r audio_test.txt -r detection_test.txt -r classification_test.txt +-r nominal_test.txt diff --git a/requirements/nominal_test.txt b/requirements/nominal_test.txt new file mode 100644 index 00000000000..ff0f19472d4 --- /dev/null +++ b/requirements/nominal_test.txt @@ -0,0 +1,2 @@ +pandas # cannot pin version due to numpy version incompatibility +dython # todo: pin version, but some version resolution issue diff --git a/src/torchmetrics/__init__.py b/src/torchmetrics/__init__.py index d0d95a29c76..4906f5ee690 100644 --- a/src/torchmetrics/__init__.py +++ b/src/torchmetrics/__init__.py @@ -53,6 +53,7 @@ UniversalImageQualityIndex, ) from torchmetrics.metric import Metric # noqa: E402 +from torchmetrics.nominal import CramersV # noqa: E402 from torchmetrics.regression import ( # noqa: E402 ConcordanceCorrCoef, CosineSimilarity, @@ -121,6 +122,7 @@ "CohenKappa", "ConfusionMatrix", "CosineSimilarity", + "CramersV", "Dice", "TweedieDevianceScore", "ErrorRelativeGlobalDimensionlessSynthesis", diff --git a/src/torchmetrics/functional/__init__.py b/src/torchmetrics/functional/__init__.py index 74e85e1b16f..60b538056ab 100644 --- a/src/torchmetrics/functional/__init__.py +++ b/src/torchmetrics/functional/__init__.py @@ -42,6 +42,7 @@ ) from torchmetrics.functional.image.tv import total_variation from torchmetrics.functional.image.uqi import universal_image_quality_index +from torchmetrics.functional.nominal.cramers import cramers_v from torchmetrics.functional.pairwise.cosine import pairwise_cosine_similarity from torchmetrics.functional.pairwise.euclidean import pairwise_euclidean_distance from torchmetrics.functional.pairwise.linear import pairwise_linear_similarity @@ -102,6 +103,7 @@ "cohen_kappa", "confusion_matrix", 
"cosine_similarity", + "cramers_v", "tweedie_deviance_score", "dice_score", "dice", diff --git a/src/torchmetrics/functional/nominal/__init__.py b/src/torchmetrics/functional/nominal/__init__.py new file mode 100644 index 00000000000..415a7b842d5 --- /dev/null +++ b/src/torchmetrics/functional/nominal/__init__.py @@ -0,0 +1,14 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from torchmetrics.functional.nominal.cramers import cramers_v, cramers_v_matrix # noqa: F401 diff --git a/src/torchmetrics/functional/nominal/cramers.py b/src/torchmetrics/functional/nominal/cramers.py new file mode 100644 index 00000000000..3b3a316d5d7 --- /dev/null +++ b/src/torchmetrics/functional/nominal/cramers.py @@ -0,0 +1,221 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import itertools +from typing import Optional, Union + +import torch +from torch import Tensor +from typing_extensions import Literal + +from torchmetrics.functional.classification.confusion_matrix import _multiclass_confusion_matrix_update +from torchmetrics.functional.nominal.utils import _handle_nan_in_data +from torchmetrics.utilities.prints import rank_zero_warn + + +def _cramers_input_validation(nan_strategy: str, nan_replace_value: Optional[Union[int, float]]) -> None: + if nan_strategy not in ["replace", "drop"]: + raise ValueError( + f"Argument `nan_strategy` is expected to be one of `['replace', 'drop']`, but got {nan_strategy}" + ) + if nan_strategy == "replace" and not isinstance(nan_replace_value, (int, float)): + raise ValueError( + "Argument `nan_replace` is expected to be of a type `int` or `float` when `nan_strategy = 'replace`, " + f"but got {nan_replace_value}" + ) + + +def _compute_expected_freqs(confmat: Tensor) -> Tensor: + """Compute the expected frequencies from the provided confusion matrix.""" + margin_sum_rows, margin_sum_cols = confmat.sum(1), confmat.sum(0) + expected_freqs = torch.einsum("r, c -> rc", margin_sum_rows, margin_sum_cols) / confmat.sum() + return expected_freqs + + +def _compute_chi_squared(confmat: Tensor, bias_correction: bool) -> Tensor: + """Chi-square test of independence of variables in a confusion matrix table. + + Adapted from: https://github.com/scipy/scipy/blob/v1.9.2/scipy/stats/contingency.py.
+ """ + expected_freqs = _compute_expected_freqs(confmat) + # Get degrees of freedom + df = expected_freqs.numel() - sum(expected_freqs.shape) + expected_freqs.ndim - 1 + if df == 0: + return torch.tensor(0.0, device=confmat.device) + + if df == 1 and bias_correction: + diff = expected_freqs - confmat + direction = diff.sign() + confmat += direction * torch.minimum(0.5 * torch.ones_like(direction), direction.abs()) + + return torch.sum((confmat - expected_freqs) ** 2 / expected_freqs) + + +def _drop_empty_rows_and_cols(confmat: Tensor) -> Tensor: + """Drop all rows and columns containing only zeros.""" + confmat = confmat[confmat.sum(1) != 0] + confmat = confmat[:, confmat.sum(0) != 0] + return confmat + + +def _cramers_v_update( + preds: Tensor, + target: Tensor, + num_classes: int, + nan_strategy: Literal["replace", "drop"] = "replace", + nan_replace_value: Optional[Union[int, float]] = 0.0, +) -> Tensor: + """Computes the bins to update the confusion matrix with for Cramer's V calculation. + + Args: + preds: 1D or 2D tensor of categorical (nominal) data + target: 1D or 2D tensor of categorical (nominal) data + num_classes: Integer specifying the number of classes + nan_strategy: Indication of whether to replace or drop ``NaN`` values + nan_replace_value: Value to replace ``NaN``s when ``nan_strategy = 'replace'`` + + Returns: + Non-reduced confusion matrix + """ + preds = preds.argmax(1) if preds.ndim == 2 else preds + target = target.argmax(1) if target.ndim == 2 else target + preds, target = _handle_nan_in_data(preds, target, nan_strategy, nan_replace_value) + return _multiclass_confusion_matrix_update(preds, target, num_classes) + + +def _cramers_v_compute(confmat: Tensor, bias_correction: bool) -> Tensor: + """Compute Cramer's V statistic based on a pre-computed confusion matrix. + + Args: + confmat: Confusion matrix for observed data + bias_correction: Indication of whether to use bias correction.
+ + Returns: + Cramer's V statistic + """ + confmat = _drop_empty_rows_and_cols(confmat) + cm_sum = confmat.sum() + chi_squared = _compute_chi_squared(confmat, bias_correction) + phi_squared = chi_squared / cm_sum + n_rows, n_cols = confmat.shape + + if bias_correction: + phi_squared_corrected = torch.max( + torch.tensor(0.0, device=confmat.device), phi_squared - ((n_rows - 1) * (n_cols - 1)) / (cm_sum - 1) + ) + rows_corrected = n_rows - (n_rows - 1) ** 2 / (cm_sum - 1) + cols_corrected = n_cols - (n_cols - 1) ** 2 / (cm_sum - 1) + if min(rows_corrected, cols_corrected) == 1: + rank_zero_warn( + "Unable to compute Cramer's V using bias correction. Please consider to set `bias_correction=False`." + ) + return torch.tensor(float("nan"), device=confmat.device) + cramers_v_value = torch.sqrt(phi_squared_corrected / min(rows_corrected - 1, cols_corrected - 1)) + else: + cramers_v_value = torch.sqrt(phi_squared / min(n_rows - 1, n_cols - 1)) + return cramers_v_value.clamp(0.0, 1.0) + + +def cramers_v( + preds: Tensor, + target: Tensor, + bias_correction: bool = True, + nan_strategy: Literal["replace", "drop"] = "replace", + nan_replace_value: Optional[Union[int, float]] = 0.0, +) -> Tensor: + r"""Compute `Cramer's V`_ statistic measuring the association between two categorical (nominal) data series. + + .. math:: + V = \sqrt{\frac{\chi^2 / n}{\min(r - 1, k - 1)}} + + where + + .. math:: + \chi^2 = \sum_{i,j} \frac{\left(n_{ij} - \frac{n_{i.} n_{.j}}{n}\right)^2}{\frac{n_{i.} n_{.j}}{n}} + + Cramer's V is a symmetric coefficient, i.e. + + .. math:: + V(preds, target) = V(target, preds) + + The output value lies in [0, 1]. + + Args: + preds: 1D or 2D tensor of categorical (nominal) data + - 1D shape: (batch_size,) + - 2D shape: (batch_size, num_classes) + target: 1D or 2D tensor of categorical (nominal) data + - 1D shape: (batch_size,) + - 2D shape: (batch_size, num_classes) + bias_correction: Indication of whether to use bias correction.
+ nan_strategy: Indication of whether to replace or drop ``NaN`` values + nan_replace_value: Value to replace ``NaN``s when ``nan_strategy = 'replace'`` + + Returns: + Cramer's V statistic + + Example: + >>> from torchmetrics.functional import cramers_v + >>> _ = torch.manual_seed(42) + >>> preds = torch.randint(0, 4, (100,)) + >>> target = torch.round(preds + torch.randn(100)).clamp(0, 4) + >>> cramers_v(preds, target) + tensor(0.5284) + """ + num_classes = len(torch.cat([preds, target]).unique()) + confmat = _cramers_v_update(preds, target, num_classes, nan_strategy, nan_replace_value) + return _cramers_v_compute(confmat, bias_correction) + + +def cramers_v_matrix( + matrix: Tensor, + bias_correction: bool = True, + nan_strategy: Literal["replace", "drop"] = "replace", + nan_replace_value: Optional[Union[int, float]] = 0.0, +) -> Tensor: + r"""Compute `Cramer's V`_ statistic between a set of multiple variables. + + This can serve as a convenient tool to compute Cramer's V statistic for analyses of correlation between categorical + variables in your dataset. + + Args: + matrix: A tensor of categorical (nominal) data, where: + - rows represent a number of data points + - columns represent a number of categorical (nominal) features + bias_correction: Indication of whether to use bias correction. 
+ nan_strategy: Indication of whether to replace or drop ``NaN`` values + nan_replace_value: Value to replace ``NaN``s when ``nan_strategy = 'replace'`` + + Returns: + Cramer's V statistic for a dataset of categorical variables + + Example: + >>> from torchmetrics.functional.nominal import cramers_v_matrix + >>> _ = torch.manual_seed(42) + >>> matrix = torch.randint(0, 4, (200, 5)) + >>> cramers_v_matrix(matrix) + tensor([[1.0000, 0.0637, 0.0000, 0.0542, 0.1337], + [0.0637, 1.0000, 0.0000, 0.0000, 0.0000], + [0.0000, 0.0000, 1.0000, 0.0000, 0.0649], + [0.0542, 0.0000, 0.0000, 1.0000, 0.1100], + [0.1337, 0.0000, 0.0649, 0.1100, 1.0000]]) + """ + _cramers_input_validation(nan_strategy, nan_replace_value) + num_variables = matrix.shape[1] + cramers_v_matrix_value = torch.ones(num_variables, num_variables, device=matrix.device) + for i, j in itertools.combinations(range(num_variables), 2): + x, y = matrix[:, i], matrix[:, j] + num_classes = len(torch.cat([x, y]).unique()) + confmat = _cramers_v_update(x, y, num_classes, nan_strategy, nan_replace_value) + cramers_v_matrix_value[i, j] = cramers_v_matrix_value[j, i] = _cramers_v_compute(confmat, bias_correction) + return cramers_v_matrix_value diff --git a/src/torchmetrics/functional/nominal/utils.py b/src/torchmetrics/functional/nominal/utils.py new file mode 100644 index 00000000000..b106d5986e8 --- /dev/null +++ b/src/torchmetrics/functional/nominal/utils.py @@ -0,0 +1,48 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Optional, Tuple + +import torch +from torch import Tensor +from typing_extensions import Literal + + +def _handle_nan_in_data( + preds: Tensor, + target: Tensor, + nan_strategy: Literal["replace", "drop"] = "replace", + nan_replace_value: Optional[float] = 0.0, +) -> Tuple[Tensor, Tensor]: + """Handle ``NaN`` values in input data. + + If ``nan_strategy = 'replace'``, all ``NaN`` values are replaced with ``nan_replace_value``. + If ``nan_strategy = 'drop'``, all rows containing ``NaN`` in either of the two vectors are dropped. + + Args: + preds: 1D tensor of categorical (nominal) data + target: 1D tensor of categorical (nominal) data + nan_strategy: Indication of whether to replace or drop ``NaN`` values + nan_replace_value: Value to replace ``NaN``s when ``nan_strategy = 'replace'`` + + Returns: + Updated ``preds`` and ``target`` tensors which contain no ``NaN`` + + Raises: + ValueError: If ``nan_strategy`` is not from ``['replace', 'drop']``. + ValueError: If ``nan_strategy = 'replace'`` and ``nan_replace_value`` is not of a type ``int`` or ``float``. + """ + if nan_strategy == "replace": + return preds.nan_to_num(nan_replace_value), target.nan_to_num(nan_replace_value) + rows_contain_nan = torch.logical_or(preds.isnan(), target.isnan()) + return preds[~rows_contain_nan], target[~rows_contain_nan] diff --git a/src/torchmetrics/nominal/__init__.py b/src/torchmetrics/nominal/__init__.py new file mode 100644 index 00000000000..c4f7d3d7208 --- /dev/null +++ b/src/torchmetrics/nominal/__init__.py @@ -0,0 +1,14 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from torchmetrics.nominal.cramers import CramersV # noqa: F401 diff --git a/src/torchmetrics/nominal/cramers.py b/src/torchmetrics/nominal/cramers.py new file mode 100644 index 00000000000..167e788e5c8 --- /dev/null +++ b/src/torchmetrics/nominal/cramers.py @@ -0,0 +1,101 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Optional, Union + +import torch +from torch import Tensor +from typing_extensions import Literal + +from torchmetrics.functional.nominal.cramers import _cramers_input_validation, _cramers_v_compute, _cramers_v_update +from torchmetrics.metric import Metric + + +class CramersV(Metric): + r"""Compute `Cramer's V`_ statistic measuring the association between two categorical (nominal) data series. + + .. math:: + V = \sqrt{\frac{\chi^2 / n}{\min(r - 1, k - 1)}} + + where + + .. math:: + \chi^2 = \sum_{i,j} \frac{\left(n_{ij} - \frac{n_{i.} n_{.j}}{n}\right)^2}{\frac{n_{i.} n_{.j}}{n}} + + Cramer's V is a symmetric coefficient, i.e. + + ..
math:: + V(preds, target) = V(target, preds) + + The output value lies in [0, 1]. + + Args: + num_classes: Integer specifying the number of classes + bias_correction: Indication of whether to use bias correction. + nan_strategy: Indication of whether to replace or drop ``NaN`` values + nan_replace_value: Value to replace ``NaN``s when ``nan_strategy = 'replace'`` + kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. + + Returns: + Cramer's V statistic + + Example: + >>> from torchmetrics import CramersV + >>> _ = torch.manual_seed(42) + >>> preds = torch.randint(0, 4, (100,)) + >>> target = torch.round(preds + torch.randn(100)).clamp(0, 4) + >>> cramers_v = CramersV(num_classes=5) + >>> cramers_v(preds, target) + tensor(0.5284) + """ + + full_state_update = False + is_differentiable = False + higher_is_better = False + confmat: Tensor + + def __init__( + self, + num_classes: int, + bias_correction: bool = True, + nan_strategy: Literal["replace", "drop"] = "replace", + nan_replace_value: Optional[Union[int, float]] = 0.0, + **kwargs: Any, + ): + super().__init__(**kwargs) + self.num_classes = num_classes + self.bias_correction = bias_correction + + _cramers_input_validation(nan_strategy, nan_replace_value) + self.nan_strategy = nan_strategy + self.nan_replace_value = nan_replace_value + + self.add_state("confmat", torch.zeros(num_classes, num_classes), dist_reduce_fx="sum") + + def update(self, preds: Tensor, target: Tensor) -> None: + """Update state with predictions and targets.
+ + Args: + preds: 1D or 2D tensor of categorical (nominal) data + - 1D shape: (batch_size,) + - 2D shape: (batch_size, num_classes) + target: 1D or 2D tensor of categorical (nominal) data + - 1D shape: (batch_size,) + - 2D shape: (batch_size, num_classes) + """ + confmat = _cramers_v_update(preds, target, self.num_classes, self.nan_strategy, self.nan_replace_value) + self.confmat += confmat + + def compute(self) -> Tensor: + """Compute Cramer's V statistic.""" + return _cramers_v_compute(self.confmat, self.bias_correction) diff --git a/tests/unittests/nominal/__init__.py b/tests/unittests/nominal/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/tests/unittests/nominal/test_cramers.py b/tests/unittests/nominal/test_cramers.py new file mode 100644 index 00000000000..5746ce6cd34 --- /dev/null +++ b/tests/unittests/nominal/test_cramers.py @@ -0,0 +1,177 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
+import itertools +import operator +from collections import namedtuple +from functools import partial + +import pytest +import torch +from dython.nominal import cramers_v as dython_cramers_v + +from torchmetrics.functional.nominal.cramers import cramers_v, cramers_v_matrix +from torchmetrics.nominal.cramers import CramersV +from torchmetrics.utilities.imports import _compare_version +from unittests.helpers.testers import BATCH_SIZE, NUM_BATCHES, MetricTester + +Input = namedtuple("Input", ["preds", "target"]) +NUM_CLASSES = 4 + +_input_default = Input( + preds=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE)), + target=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE)), +) + +# Requires float type to pass NaNs +_preds = torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE), dtype=torch.float) +_preds[0, 0] = float("nan") +_preds[-1, -1] = float("nan") +_target = torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE), dtype=torch.float) +_target[1, 0] = float("nan") +_target[-1, 0] = float("nan") +_input_with_nans = Input(preds=_preds, target=_target) + +_input_logits = Input( + preds=torch.rand(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES), target=torch.rand(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES) +) + + +@pytest.fixture +def _matrix_input(): + matrix = torch.cat( + [ + torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES * BATCH_SIZE, 1), dtype=torch.float), + torch.randint(high=NUM_CLASSES + 2, size=(NUM_BATCHES * BATCH_SIZE, 1), dtype=torch.float), + torch.randint(high=2, size=(NUM_BATCHES * BATCH_SIZE, 1), dtype=torch.float), + ], + dim=-1, + ) + matrix[0, 0] = float("nan") + matrix[-1, -1] = float("nan") + return matrix + + +def _dython_cramers_v(preds, target, bias_correction, nan_strategy, nan_replace_value): + preds = preds.argmax(1) if preds.ndim == 2 else preds + target = target.argmax(1) if target.ndim == 2 else target + + v = dython_cramers_v( + preds.numpy(), + target.numpy(), + bias_correction=bias_correction, + 
nan_strategy=nan_strategy, + nan_replace_value=nan_replace_value, + ) + return torch.tensor(v) + + +def _dython_cramers_v_matrix(matrix, bias_correction, nan_strategy, nan_replace_value): + num_variables = matrix.shape[1] + cramers_v_matrix_value = torch.ones(num_variables, num_variables) + for i, j in itertools.combinations(range(num_variables), 2): + x, y = matrix[:, i], matrix[:, j] + cramers_v_matrix_value[i, j] = cramers_v_matrix_value[j, i] = _dython_cramers_v( + x, y, bias_correction, nan_strategy, nan_replace_value + ) + return cramers_v_matrix_value + + +@pytest.mark.skipif( + _compare_version("pandas", operator.lt, "1.3.2"), reason="`dython` package requires `pandas>=1.3.2`" +) +@pytest.mark.skipif( # TODO: testing on CUDA fails with pandas 1.3.5, and newer is not available for python 3.7 + torch.cuda.is_available(), reason="Tests fail on CUDA with the most up-to-date available pandas" +) +@pytest.mark.parametrize( + "preds, target", + [ + (_input_default.preds, _input_default.target), + (_input_with_nans.preds, _input_with_nans.target), + (_input_logits.preds, _input_logits.target), + ], +) +@pytest.mark.parametrize("bias_correction", [False, True]) +@pytest.mark.parametrize("nan_strategy, nan_replace_value", [("replace", 0.0), ("drop", None)]) +class TestCramersV(MetricTester): + atol = 1e-5 + + @pytest.mark.parametrize("ddp", [False, True]) + @pytest.mark.parametrize("dist_sync_on_step", [False, True]) + def test_cramers_v(self, ddp, dist_sync_on_step, preds, target, bias_correction, nan_strategy, nan_replace_value): + metric_args = { + "bias_correction": bias_correction, + "nan_strategy": nan_strategy, + "nan_replace_value": nan_replace_value, + "num_classes": NUM_CLASSES, + } + reference_metric = partial( + _dython_cramers_v, + bias_correction=bias_correction, + nan_strategy=nan_strategy, + nan_replace_value=nan_replace_value, + ) + self.run_class_metric_test( + ddp=ddp, + dist_sync_on_step=dist_sync_on_step, + preds=preds, + target=target, + 
metric_class=CramersV, + sk_metric=reference_metric, + metric_args=metric_args, + ) + + def test_cramers_v_functional(self, preds, target, bias_correction, nan_strategy, nan_replace_value): + metric_args = { + "bias_correction": bias_correction, + "nan_strategy": nan_strategy, + "nan_replace_value": nan_replace_value, + } + reference_metric = partial( + _dython_cramers_v, + bias_correction=bias_correction, + nan_strategy=nan_strategy, + nan_replace_value=nan_replace_value, + ) + self.run_functional_metric_test( + preds, target, metric_functional=cramers_v, sk_metric=reference_metric, metric_args=metric_args + ) + + def test_cramers_v_differentiability(self, preds, target, bias_correction, nan_strategy, nan_replace_value): + metric_args = { + "bias_correction": bias_correction, + "nan_strategy": nan_strategy, + "nan_replace_value": nan_replace_value, + "num_classes": NUM_CLASSES, + } + self.run_differentiability_test( + preds, + target, + metric_module=CramersV, + metric_functional=cramers_v, + metric_args=metric_args, + ) + + +@pytest.mark.skipif( + _compare_version("pandas", operator.lt, "1.3.2"), reason="`dython` package requires `pandas>=1.3.2`" +) +@pytest.mark.skipif( # TODO: testing on CUDA fails with pandas 1.3.5, and newer is not available for python 3.7 + torch.cuda.is_available(), reason="Tests fail on CUDA with the most up-to-date available pandas" +) +@pytest.mark.parametrize("bias_correction", [False, True]) +@pytest.mark.parametrize("nan_strategy, nan_replace_value", [("replace", 1.0), ("drop", None)]) +def test_cramers_v_matrix(_matrix_input, bias_correction, nan_strategy, nan_replace_value): + tm_score = cramers_v_matrix(_matrix_input, bias_correction, nan_strategy, nan_replace_value) + reference_score = _dython_cramers_v_matrix(_matrix_input, bias_correction, nan_strategy, nan_replace_value) + assert torch.allclose(tm_score, reference_score)