Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Cramer's V (Cramer's Phi) #1298

Merged
merged 55 commits into from
Nov 14, 2022
Merged
Show file tree
Hide file tree
Changes from 37 commits
Commits
Show all changes
55 commits
Select commit Hold shift + click to select a range
82dac55
Add Cramer's V metric
stancld Oct 28, 2022
ba6cf16
Merge branch 'master' into metric/cramers-phi
stancld Oct 28, 2022
4418797
Add example for matrix metric
stancld Oct 28, 2022
6b4ea3a
Clean tests
stancld Oct 28, 2022
08a7291
Format matrix metric doc
stancld Oct 28, 2022
092b20e
Clean docs
stancld Oct 28, 2022
b5c7f18
Fix random seed for doctest
stancld Oct 28, 2022
5a58263
Remove redundant ) in module metric docstring
stancld Oct 28, 2022
3069f7d
Fix docs
stancld Oct 28, 2022
9b68bc9
Delete redundant mypy ignore
stancld Oct 28, 2022
1b0ada9
Conditional import and tests if dython available
stancld Oct 28, 2022
cd6328e
Merge branch 'master' into metric/cramers-phi
stancld Oct 28, 2022
7724cf9
Skip last test as well and replace torch.nan with float(nan)
stancld Oct 28, 2022
0b31253
Apply suggestions from code review
Borda Oct 28, 2022
a2c0e0b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 28, 2022
f8d52be
import dythom
Borda Oct 28, 2022
a235430
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 28, 2022
7f7be24
add new requirement file to devel.txt
justusschock Oct 31, 2022
1bfec24
remove dython checks
justusschock Oct 31, 2022
ac8b555
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 31, 2022
e58aa1e
forgot to remove import
justusschock Oct 31, 2022
977238c
Merge branch 'master' into metric/cramers-phi
stancld Oct 31, 2022
c340398
Merge branch 'master' into metric/cramers-phi
SkafteNicki Oct 31, 2022
e217801
Merge branch 'master' into metric/cramers-phi
SkafteNicki Nov 4, 2022
5265751
Merge branch 'master' into metric/cramers-phi
SkafteNicki Nov 5, 2022
5f33b02
Apply suggestsions from code review + Fix a few typos + Update chlog
stancld Nov 6, 2022
9b52fde
Refactor input validation for the module metric
stancld Nov 6, 2022
fbeee68
Update docs
stancld Nov 6, 2022
1d34189
Fix doc example for matrix calculation
stancld Nov 6, 2022
aeb34a2
Try drop full_state_update = False
stancld Nov 6, 2022
89739fd
Update matrix test
stancld Nov 6, 2022
9abb275
Bump up pandas version for oldest config test suite
stancld Nov 6, 2022
f8ddfc5
Skip tests with pandas<1.3.2
stancld Nov 6, 2022
0947c58
Fix pytest.mark.skipif spec
stancld Nov 6, 2022
34e191a
Skip matrix test for oldest as well
stancld Nov 6, 2022
ffea71b
Merge branch 'master' into metric/cramers-phi
stancld Nov 8, 2022
2844236
move aggregation
SkafteNicki Nov 8, 2022
9082ea4
Merge branch 'master' into metric/cramers-phi
mergify[bot] Nov 8, 2022
20038be
Merge branch 'master' into metric/cramers-phi
mergify[bot] Nov 8, 2022
b2e52f1
Add missing return to docstring
stancld Nov 8, 2022
d8c3fa9
dython
Borda Nov 8, 2022
c5379be
Merge branch 'master' into metric/cramers-phi
mergify[bot] Nov 8, 2022
ed25850
Apply suggestions from code review
Borda Nov 8, 2022
ad06101
retrigger tests
SkafteNicki Nov 9, 2022
60cc052
lower atol
SkafteNicki Nov 9, 2022
5c3c8ab
try longer testing
SkafteNicki Nov 9, 2022
0f941a9
fix pandas
SkafteNicki Nov 10, 2022
7e4ab8f
Merge branch 'master' into metric/cramers-phi
mergify[bot] Nov 10, 2022
7471f28
Use pandas>=1.3.2
stancld Nov 10, 2022
a1bce56
trying 1.4.0
SkafteNicki Nov 10, 2022
bac9d70
Merge branch 'master' into metric/cramers-phi
mergify[bot] Nov 11, 2022
f689e54
Try to hack reqs
stancld Nov 13, 2022
4d130dd
Try to hack reqs
stancld Nov 13, 2022
feb1f71
Try to skip tests on GPU
stancld Nov 13, 2022
c23b0de
Skip matrix test + unpin pandas due to numpy mismatch
stancld Nov 13, 2022
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

- Added `TotalVariation` to image package ([#978](https://github.com/Lightning-AI/metrics/pull/978))


- Added option to pass `distributed_available_fn` to metrics to allow checks for custom communication backend for making `dist_sync_fn` actually useful ([#1301](https://github.com/Lightning-AI/metrics/pull/1301))


Expand All @@ -22,6 +23,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added `LogCoshError` to regression package ([#1316](https://github.com/Lightning-AI/metrics/pull/1316))


- Added `CramersV` to the new nominal package ([#1298](https://github.com/Lightning-AI/metrics/pull/1298))


### Changed

- Changed `MeanAveragePrecision` to vectorize `_find_best_gt_match` operation ([#1259](https://github.com/Lightning-AI/metrics/pull/1259))
Expand Down
30 changes: 19 additions & 11 deletions docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,14 @@ Or directly from conda
pages/lightning
pages/retrieval

.. toctree::
:maxdepth: 2
:name: aggregation
:caption: Aggregation
:glob:

aggregation/*

.. toctree::
:maxdepth: 2
:name: audio
Expand All @@ -150,6 +158,14 @@ Or directly from conda

classification/*

.. toctree::
:maxdepth: 2
:name: detection
:caption: Detection
:glob:

detection/*

.. toctree::
:maxdepth: 2
:name: image
Expand All @@ -160,11 +176,11 @@ Or directly from conda

.. toctree::
:maxdepth: 2
:name: detection
:caption: Detection
:name: nominal
:caption: Nominal
:glob:

detection/*
nominal/*

.. toctree::
:maxdepth: 2
Expand Down Expand Up @@ -198,14 +214,6 @@ Or directly from conda

text/*

.. toctree::
:maxdepth: 2
:name: aggregation
:caption: Aggregation
:glob:

aggregation/*

.. toctree::
:maxdepth: 2
:name: wrappers
Expand Down
1 change: 1 addition & 0 deletions docs/source/links.rst
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@
.. _AB divergence: https://pdfs.semanticscholar.org/744b/1166de34cb099100f151f3b1459f141ae25b.pdf
.. _Rényi divergence: https://static.renyi.hu/renyi_cikkek/1961_on_measures_of_entropy_and_information.pdf
.. _Fisher-Rao distance: http://www.scholarpedia.org/article/Fisher-Rao_metric
.. _Cramer's V: https://en.wikipedia.org/wiki/Cram%C3%A9r%27s_V
.. _Kendall Rank Correlation Coefficient: https://en.wikipedia.org/wiki/Kendall_rank_correlation_coefficient
.. _The Treatment of Ties in Ranking Problems: https://www.jstor.org/stable/2332303
.. _LogCosh Error: https://arxiv.org/pdf/2101.10427.pdf
26 changes: 26 additions & 0 deletions docs/source/nominal/cramers_v.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
.. customcarditem::
:header: Cramer's V
:image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/tabular_classification.svg
:tags: Nominal

##########
Cramer's V
##########

Module Interface
________________

.. autoclass:: torchmetrics.CramersV
:noindex:

Functional Interface
____________________

.. autofunction:: torchmetrics.functional.cramers_v
:noindex:

cramers_v_matrix
^^^^^^^^^^^^^^^^

.. autofunction:: torchmetrics.functional.nominal.cramers_v_matrix
:noindex:
1 change: 1 addition & 0 deletions requirements/devel.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,4 @@
-r audio_test.txt
-r detection_test.txt
-r classification_test.txt
-r nominal_test.txt
1 change: 1 addition & 0 deletions requirements/nominal_test.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
dython
Borda marked this conversation as resolved.
Show resolved Hide resolved
2 changes: 2 additions & 0 deletions src/torchmetrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
UniversalImageQualityIndex,
)
from torchmetrics.metric import Metric # noqa: E402
from torchmetrics.nominal import CramersV # noqa: E402
from torchmetrics.regression import ( # noqa: E402
ConcordanceCorrCoef,
CosineSimilarity,
Expand Down Expand Up @@ -121,6 +122,7 @@
"CohenKappa",
"ConfusionMatrix",
"CosineSimilarity",
"CramersV",
"Dice",
"TweedieDevianceScore",
"ErrorRelativeGlobalDimensionlessSynthesis",
Expand Down
2 changes: 2 additions & 0 deletions src/torchmetrics/functional/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
)
from torchmetrics.functional.image.tv import total_variation
from torchmetrics.functional.image.uqi import universal_image_quality_index
from torchmetrics.functional.nominal.cramers import cramers_v
from torchmetrics.functional.pairwise.cosine import pairwise_cosine_similarity
from torchmetrics.functional.pairwise.euclidean import pairwise_euclidean_distance
from torchmetrics.functional.pairwise.linear import pairwise_linear_similarity
Expand Down Expand Up @@ -102,6 +103,7 @@
"cohen_kappa",
"confusion_matrix",
"cosine_similarity",
"cramers_v",
"tweedie_deviance_score",
"dice_score",
"dice",
Expand Down
14 changes: 14 additions & 0 deletions src/torchmetrics/functional/nominal/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from torchmetrics.functional.nominal.cramers import cramers_v, cramers_v_matrix # noqa: F401
221 changes: 221 additions & 0 deletions src/torchmetrics/functional/nominal/cramers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,221 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import itertools
from typing import Optional, Union

import torch
from torch import Tensor
from typing_extensions import Literal

from torchmetrics.functional.classification.confusion_matrix import _multiclass_confusion_matrix_update
from torchmetrics.functional.nominal.utils import _handle_nan_in_data
from torchmetrics.utilities.prints import rank_zero_warn


def _cramers_input_validation(nan_strategy: str, nan_replace_value: Optional[Union[int, float]]) -> None:
if nan_strategy not in ["replace", "drop"]:
raise ValueError(
f"Argument `nan_strategy` is expected to be one of `['replace', 'drop']`, but got {nan_strategy}"
)
if nan_strategy == "replace" and not isinstance(nan_replace_value, (int, float)):
raise ValueError(
"Argument `nan_replace` is expected to be of a type `int` or `float` when `nan_strategy = 'replace`, "
f"but got {nan_replace_value}"
)


def _compute_expected_freqs(confmat: Tensor) -> Tensor:
"""Compute the expected frequenceis from the provided confusion matrix."""
margin_sum_rows, margin_sum_cols = confmat.sum(1), confmat.sum(0)
expected_freqs = torch.einsum("r, c -> rc", margin_sum_rows, margin_sum_cols) / confmat.sum()
return expected_freqs


def _compute_chi_squared(confmat: Tensor, bias_correction: bool) -> Tensor:
"""Chi-square test of independenc of variables in a confusion matrix table.

Adapted from: https://github.com/scipy/scipy/blob/v1.9.2/scipy/stats/contingency.py.
"""
expected_freqs = _compute_expected_freqs(confmat)
# Get degrees of freedom
df = expected_freqs.numel() - sum(expected_freqs.shape) + expected_freqs.ndim - 1
if df == 0:
return torch.tensor(0.0, device=confmat.device)

if df == 1 and bias_correction:
diff = expected_freqs - confmat
direction = diff.sign()
confmat += direction * torch.minimum(0.5 * torch.ones_like(direction), direction.abs())

return torch.sum((confmat - expected_freqs) ** 2 / expected_freqs)


def _drop_empty_rows_and_cols(confmat: Tensor) -> Tensor:
"""Drop all rows and columns containing only zeros."""
confmat = confmat[confmat.sum(1) != 0]
confmat = confmat[:, confmat.sum(0) != 0]
return confmat


def _cramers_v_update(
preds: Tensor,
target: Tensor,
num_classes: int,
nan_strategy: Literal["replace", "drop"] = "replace",
nan_replace_value: Optional[Union[int, float]] = 0.0,
) -> Tensor:
"""Computes the bins to update the confusion matrix with for Cramer's V calculation.

Args:
preds: 1D or 2D tensor of categorical (nominal) data
target: 1D or 2D tensor of categorical (nominal) data
num_classes: Integer specifing the number of classes
nan_strategy: Indication of whether to replace or drop ``NaN`` values
nan_replace_value: Value to replace ``NaN`s when ``nan_strategy = 'replace```

Returns:
Non-reduced confusion matrix
"""
preds = preds.argmax(1) if preds.ndim == 2 else preds
target = target.argmax(1) if target.ndim == 2 else target
preds, target = _handle_nan_in_data(preds, target, nan_strategy, nan_replace_value)
return _multiclass_confusion_matrix_update(preds, target, num_classes)


def _cramers_v_compute(confmat: Tensor, bias_correction: bool) -> Tensor:
"""Compute Cramers' V statistic based on a pre-computed confusion matrix.

Args:
confmat: Confusion matrix for observed data
bias_correction: Indication of whether to use bias correction.

Returns:
Cramer's V statistic
"""
confmat = _drop_empty_rows_and_cols(confmat)
cm_sum = confmat.sum()
chi_squared = _compute_chi_squared(confmat, bias_correction)
phi_squared = chi_squared / cm_sum
n_rows, n_cols = confmat.shape

if bias_correction:
phi_squared_corrected = torch.max(
torch.tensor(0.0, device=confmat.device), phi_squared - ((n_rows - 1) * (n_cols - 1)) / (cm_sum - 1)
)
rows_corrected = n_rows - (n_rows - 1) ** 2 / (cm_sum - 1)
cols_corrected = n_cols - (n_cols - 1) ** 2 / (cm_sum - 1)
if min(rows_corrected, cols_corrected) == 1:
rank_zero_warn(
"Unable to compute Cramer's V using bias correction. Please consider to set `bias_correction=False`."
)
return torch.tensor(float("nan"), device=confmat.device)
cramers_v_value = torch.sqrt(phi_squared_corrected / min(rows_corrected - 1, cols_corrected - 1))
else:
cramers_v_value = torch.sqrt(phi_squared / min(n_rows - 1, n_cols - 1))
return cramers_v_value.clamp(0.0, 1.0)


def cramers_v(
preds: Tensor,
target: Tensor,
bias_correction: bool = True,
nan_strategy: Literal["replace", "drop"] = "replace",
nan_replace_value: Optional[Union[int, float]] = 0.0,
) -> Tensor:
r"""Compute `Cramer's V`_ statistic measuring the association between two categorical (nominal) data series.

.. math::
V = \sqrt{\frac{\chi^2 / 2}{\min(r - 1, k - 1)}}

where

.. math::
\chi^2 = \sum_{i,j} \ frac{\left(n_{ij} - \frac{n_{i.} n_{.j}}{n}\right)^2}{\frac{n_{i.} n_{.j}}{n}}

Cramer's V is a symmetric coefficient, i.e.

.. math::
V(preds, target) = V(target, preds)

The output values lies in [0, 1].

Args:
preds: 1D or 2D tensor of categorical (nominal) data
- 1D shape: (batch_size,)
- 2D shape: (batch_size, num_classes)
target: 1D or 2D tensor of categorical (nominal) data
- 1D shape: (batch_size,)
- 2D shape: (batch_size, num_classes)
bias_correction: Indication of whether to use bias correction.
nan_strategy: Indication of whether to replace or drop ``NaN`` values
nan_replace_value: Value to replace ``NaN``s when ``nan_strategy = 'replace'``

Returns:
Cramer's V statistic

Example:
>>> from torchmetrics.functional import cramers_v
>>> _ = torch.manual_seed(42)
>>> preds = torch.randint(0, 4, (100,))
>>> target = torch.round(preds + torch.randn(100)).clamp(0, 4)
>>> cramers_v(preds, target)
tensor(0.5284)
"""
num_classes = len(torch.cat([preds, target]).unique())
confmat = _cramers_v_update(preds, target, num_classes, nan_strategy, nan_replace_value)
return _cramers_v_compute(confmat, bias_correction)


def cramers_v_matrix(
matrix: Tensor,
bias_correction: bool = True,
nan_strategy: Literal["replace", "drop"] = "replace",
nan_replace_value: Optional[Union[int, float]] = 0.0,
) -> Tensor:
r"""Compute `Cramer's V`_ statistic between a set of multiple variables.

This can serve as a convenient tool to compute Cramer's V statistic for analyses of correlation between categorical
variables in your dataset.

Args:
matrix: A tensor of categorical (nominal) data, where:
- rows represent a number of data points
- columns represent a number of categorical (nominal) features
bias_correction: Indication of whether to use bias correction.
nan_strategy: Indication of whether to replace or drop ``NaN`` values
nan_replace_value: Value to replace ``NaN``s when ``nan_strategy = 'replace'``

Returns:
Cramer's V statistic for a dataset of categorical variables

Example:
>>> from torchmetrics.functional.nominal import cramers_v_matrix
>>> _ = torch.manual_seed(42)
>>> matrix = torch.randint(0, 4, (200, 5))
>>> cramers_v_matrix(matrix)
tensor([[1.0000, 0.0637, 0.0000, 0.0542, 0.1337],
[0.0637, 1.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 1.0000, 0.0000, 0.0649],
[0.0542, 0.0000, 0.0000, 1.0000, 0.1100],
[0.1337, 0.0000, 0.0649, 0.1100, 1.0000]])
"""
_cramers_input_validation(nan_strategy, nan_replace_value)
num_variables = matrix.shape[1]
cramers_v_matrix_value = torch.ones(num_variables, num_variables, device=matrix.device)
for i, j in itertools.combinations(range(num_variables), 2):
x, y = matrix[:, i], matrix[:, j]
num_classes = len(torch.cat([x, y]).unique())
confmat = _cramers_v_update(x, y, num_classes, nan_strategy, nan_replace_value)
cramers_v_matrix_value[i, j] = cramers_v_matrix_value[j, i] = _cramers_v_compute(confmat, bias_correction)
return cramers_v_matrix_value
Loading