From bef0ff3fc4a791c94f72e5bbf7455a84e69eaf61 Mon Sep 17 00:00:00 2001
From: Nicki Skafte
Date: Fri, 1 Dec 2023 16:03:36 +0100
Subject: [PATCH 01/12] add doc pages

---
 docs/source/retrieval/auroc.rst              | 21 +++++++++++++++++++
 .../retrieval/precision_recall_curve.rst     |  8 +++----
 2 files changed, 25 insertions(+), 4 deletions(-)
 create mode 100644 docs/source/retrieval/auroc.rst

diff --git a/docs/source/retrieval/auroc.rst b/docs/source/retrieval/auroc.rst
new file mode 100644
index 00000000000..7890b3bf248
--- /dev/null
+++ b/docs/source/retrieval/auroc.rst
@@ -0,0 +1,21 @@
+.. customcarditem::
+    :header: Retrieval AUROC
+    :image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/text_classification.svg
+    :tags: Retrieval
+
+.. include:: ../links.rst
+
+###############
+Retrieval AUROC
+###############
+
+Module Interface
+________________
+
+.. autoclass:: torchmetrics.retrieval.RetrievalAUROC
+    :exclude-members: update, compute
+
+Functional Interface
+____________________
+
+.. autofunction:: torchmetrics.functional.retrieval.retrieval_auroc

diff --git a/docs/source/retrieval/precision_recall_curve.rst b/docs/source/retrieval/precision_recall_curve.rst
index f3c77522954..bb976e9e9ef 100644
--- a/docs/source/retrieval/precision_recall_curve.rst
+++ b/docs/source/retrieval/precision_recall_curve.rst
@@ -1,13 +1,13 @@
 .. customcarditem::
-    :header: Precision Recall Curve
+    :header: Retrieval Precision Recall Curve
     :image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/text_classification.svg
     :tags: Retrieval

 .. include:: ../links.rst

-######################
-Precision Recall Curve
-######################
+################################
+Retrieval Precision Recall Curve
+################################

 Module Interface
 ________________

From f4b0f19d0c3c228c558e5850907b3bb7c082a8c6 Mon Sep 17 00:00:00 2001
From: Nicki Skafte
Date: Fri, 1 Dec 2023 16:04:12 +0100
Subject: [PATCH 02/12] add init files

---
 src/torchmetrics/functional/retrieval/__init__.py | 3 ++-
 src/torchmetrics/retrieval/__init__.py            | 2 ++
 2 files changed, 4 insertions(+), 1 deletion(-)

diff --git a/src/torchmetrics/functional/retrieval/__init__.py b/src/torchmetrics/functional/retrieval/__init__.py
index 97063abbb2f..069b6a3089c 100644
--- a/src/torchmetrics/functional/retrieval/__init__.py
+++ b/src/torchmetrics/functional/retrieval/__init__.py
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
+from torchmetrics.functional.retrieval.auroc import retrieval_auroc
 from torchmetrics.functional.retrieval.average_precision import retrieval_average_precision
 from torchmetrics.functional.retrieval.fall_out import retrieval_fall_out
 from torchmetrics.functional.retrieval.hit_rate import retrieval_hit_rate
@@ -23,6 +23,7 @@
 from torchmetrics.functional.retrieval.reciprocal_rank import retrieval_reciprocal_rank

 __all__ = [
+    "retrieval_auroc",
     "retrieval_average_precision",
     "retrieval_fall_out",
     "retrieval_hit_rate",

diff --git a/src/torchmetrics/retrieval/__init__.py b/src/torchmetrics/retrieval/__init__.py
index ec80002da9a..18a9df576af 100644
--- a/src/torchmetrics/retrieval/__init__.py
+++ b/src/torchmetrics/retrieval/__init__.py
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from torchmetrics.retrieval.auroc import RetrievalAUROC
 from torchmetrics.retrieval.average_precision import RetrievalMAP
 from torchmetrics.retrieval.fall_out import RetrievalFallOut
 from torchmetrics.retrieval.hit_rate import RetrievalHitRate
@@ -22,6 +23,7 @@
 from torchmetrics.retrieval.reciprocal_rank import RetrievalMRR

 __all__ = [
+    "RetrievalAUROC",
     "RetrievalFallOut",
     "RetrievalHitRate",
     "RetrievalMAP",

From 44eac59b418c7c729fc97f455352b4932689a123 Mon Sep 17 00:00:00 2001
From: Nicki Skafte
Date: Fri, 1 Dec 2023 16:08:17 +0100
Subject: [PATCH 03/12] functional implementation

---
 .../functional/retrieval/auroc.py             | 65 +++++++++++++++++++
 1 file changed, 65 insertions(+)
 create mode 100644 src/torchmetrics/functional/retrieval/auroc.py

diff --git a/src/torchmetrics/functional/retrieval/auroc.py b/src/torchmetrics/functional/retrieval/auroc.py
new file mode 100644
index 00000000000..33c88643ffa
--- /dev/null
+++ b/src/torchmetrics/functional/retrieval/auroc.py
@@ -0,0 +1,65 @@
+# Copyright The Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Optional
+
+from torch import Tensor, tensor
+
+from torchmetrics.functional.classification.auroc import binary_auroc
+from torchmetrics.utilities.checks import _check_retrieval_functional_inputs
+
+
+def retrieval_auroc(
+    preds: Tensor, target: Tensor, top_k: Optional[int] = None, max_fpr: Optional[float] = None
+) -> Tensor:
+    """Compute area under the receiver operating characteristic curve (AUROC) for information retrieval.
+
+    ``preds`` and ``target`` should be of the same shape and live on the same device. If no ``target`` is ``True``,
+    ``0`` is returned. ``target`` must be either `bool` or `integers` and ``preds`` must be ``float``,
+    otherwise an error is raised.
+
+    Args:
+        preds: estimated probabilities of each document to be relevant.
+        target: ground truth about each document being relevant or not.
+        top_k: consider only the top k elements (default: ``None``, which considers them all)
+        max_fpr: If not ``None``, calculates standardized partial AUC over the range ``[0, max_fpr]``.
+
+    Return:
+        A single-value tensor with the AUROC value of the predictions ``preds`` w.r.t. the labels ``target``.
+
+    Raises:
+        ValueError:
+            If ``top_k`` is neither ``None`` nor an integer larger than 0.
+
+    Example:
+        >>> from torch import tensor
+        >>> from torchmetrics.functional.retrieval import retrieval_auroc
+        >>> preds = tensor([0.2, 0.3, 0.5])
+        >>> target = tensor([True, False, True])
+        >>> retrieval_auroc(preds, target)
+        tensor(0.5000)
+
+    """
+    preds, target = _check_retrieval_functional_inputs(preds, target)
+
+    top_k = top_k or preds.shape[-1]
+    if not isinstance(top_k, int) and top_k <= 0:
+        raise ValueError(f"Argument ``top_k`` has to be a positive integer or None, but got {top_k}.")
+
+    top_k_idx = preds.topk(min(top_k, preds.shape[-1]), sorted=True, dim=-1)[1]
+    target = target[top_k_idx]
+    if not target.sum():
+        return tensor(0.0, device=preds.device)
+
+    preds = preds[top_k_idx]
+    return binary_auroc(preds, target, max_fpr=max_fpr)

From 813cc7e7869d348bc076d3a8fe9f19cba8bb73e1 Mon Sep 17 00:00:00 2001
From: Nicki Skafte
Date: Fri, 1 Dec 2023 16:12:30 +0100
Subject: [PATCH 04/12] Add RetrievalAUROC metric to torchmetrics

---
 src/torchmetrics/retrieval/auroc.py | 160 ++++++++++++++++++++++++++++
 1 file changed, 160 insertions(+)
 create mode 100644 src/torchmetrics/retrieval/auroc.py

diff --git a/src/torchmetrics/retrieval/auroc.py b/src/torchmetrics/retrieval/auroc.py
new file mode 100644
index 00000000000..f0f7245fe08
--- /dev/null
+++ b/src/torchmetrics/retrieval/auroc.py
@@ -0,0 +1,160 @@
+# Copyright The Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Any, Callable, Optional, Sequence, Union
+
+from torch import Tensor
+from typing_extensions import Literal
+
+from torchmetrics.functional.classification.auroc import binary_auroc
+from torchmetrics.retrieval.base import RetrievalMetric
+from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE
+from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE
+
+if not _MATPLOTLIB_AVAILABLE:
+    __doctest_skip__ = ["RetrievalMAP.plot"]
+
+
+class RetrievalAUROC(RetrievalMetric):
+    """Compute area under the receiver operating characteristic curve (AUROC) for information retrieval.
+
+    Works with binary target data. Accepts float predictions from a model output.
+
+    As input to ``forward`` and ``update`` the metric accepts the following input:
+
+    - ``preds`` (:class:`~torch.Tensor`): A float tensor of shape ``(N, ...)``
+    - ``target`` (:class:`~torch.Tensor`): A long or bool tensor of shape ``(N, ...)``
+    - ``indexes`` (:class:`~torch.Tensor`): A long tensor of shape ``(N, ...)`` which indicates to which query a
+      prediction belongs
+
+    As output to ``forward`` and ``compute`` the metric returns the following output:
+
+    - ``auroc@k`` (:class:`~torch.Tensor`): A single-value tensor with the AUROC value
+      of the predictions ``preds`` w.r.t. the labels ``target``.
+
+    All ``indexes``, ``preds`` and ``target`` must have the same dimension and will be flattened at the beginning,
+    so that, for example, a tensor of shape ``(N, M)`` is treated as ``(N * M, )``. Predictions will first be grouped
+    by ``indexes`` and then the metric will be computed as the mean over each query.
+
+    Args:
+        empty_target_action:
+            Specify what to do with queries that do not have at least one positive ``target``. Choose from:
+
+            - ``'neg'``: those queries count as ``0.0`` (default)
+            - ``'pos'``: those queries count as ``1.0``
+            - ``'skip'``: skip those queries; if all queries are skipped, ``0.0`` is returned
+            - ``'error'``: raise a ``ValueError``
+
+        ignore_index: Ignore predictions where the target is equal to this number.
+        top_k: Consider only the top k elements for each query (default: ``None``, which considers them all)
+        max_fpr: If not ``None``, calculates standardized partial AUC over the range ``[0, max_fpr]``.
+        aggregation:
+            Specify how to aggregate over indexes. Can either be a custom callable function that takes in a single
+            tensor and returns a scalar value or one of the following strings:
+
+            - ``'mean'``: average value is returned
+            - ``'median'``: median value is returned
+            - ``'max'``: max value is returned
+            - ``'min'``: min value is returned
+
+        kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
+
+    Raises:
+        ValueError:
+            If ``empty_target_action`` is not one of ``error``, ``skip``, ``neg`` or ``pos``.
+        ValueError:
+            If ``ignore_index`` is not `None` or an integer.
+        ValueError:
+            If ``top_k`` is neither ``None`` nor an integer greater than 0.
+
+    Example:
+        >>> from torch import tensor
+        >>> from torchmetrics.retrieval import RetrievalAUROC
+        >>> indexes = tensor([0, 0, 0, 1, 1, 1, 1])
+        >>> preds = tensor([0.2, 0.3, 0.5, 0.1, 0.3, 0.5, 0.2])
+        >>> target = tensor([False, False, True, False, True, False, True])
+        >>> auroc = RetrievalAUROC()
+        >>> auroc(preds, target, indexes=indexes)
+        tensor(0.7500)
+
+    """
+
+    is_differentiable: bool = False
+    higher_is_better: bool = True
+    full_state_update: bool = False
+    plot_lower_bound: float = 0.0
+    plot_upper_bound: float = 1.0
+
+    def __init__(
+        self,
+        empty_target_action: str = "neg",
+        ignore_index: Optional[int] = None,
+        max_fpr: Optional[float] = None,
+        aggregation: Union[Literal["mean", "median", "min", "max"], Callable] = "mean",
+        **kwargs: Any,
+    ) -> None:
+        super().__init__(
+            empty_target_action=empty_target_action,
+            ignore_index=ignore_index,
+            aggregation=aggregation,
+            **kwargs,
+        )
+
+        if max_fpr is not None and not (isinstance(max_fpr, float) and 0 < max_fpr <= 1):
+            raise ValueError(f"Argument `max_fpr` should be a float in range (0, 1], but got: {max_fpr}")
+        self.max_fpr = max_fpr
+
+    def _metric(self, preds: Tensor, target: Tensor) -> Tensor:
+        return binary_auroc(preds, target, max_fpr=self.max_fpr)
+
+    def plot(
+        self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
+    ) -> _PLOT_OUT_TYPE:
+        """Plot a single or multiple values from the metric.
+
+        Args:
+            val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
+                If no value is provided, will automatically call `metric.compute` and plot that result.
+            ax: A matplotlib axis object. If provided, will add plot to that axis
+
+        Returns:
+            Figure and Axes object
+
+        Raises:
+            ModuleNotFoundError:
+                If `matplotlib` is not installed
+
+        .. plot::
+            :scale: 75
+
+            >>> import torch
+            >>> from torchmetrics.retrieval import RetrievalAUROC
+            >>> # Example plotting a single value
+            >>> metric = RetrievalAUROC()
+            >>> metric.update(torch.rand(10,), torch.randint(2, (10,)), indexes=torch.randint(2,(10,)))
+            >>> fig_, ax_ = metric.plot()
+
+        .. plot::
+            :scale: 75
+
+            >>> import torch
+            >>> from torchmetrics.retrieval import RetrievalAUROC
+            >>> # Example plotting multiple values
+            >>> metric = RetrievalAUROC()
+            >>> values = []
+            >>> for _ in range(10):
+            ...     values.append(metric(torch.rand(10,), torch.randint(2, (10,)), indexes=torch.randint(2,(10,))))
+            >>> fig, ax = metric.plot(values)
+
+        """
+        return self._plot(val, ax)

From e213990313b355480ab74287bebf2255b6db04ba Mon Sep 17 00:00:00 2001
From: Nicki Skafte
Date: Fri, 1 Dec 2023 16:42:31 +0100
Subject: [PATCH 05/12] Refactor retrieval AUROC metric to support top-k
 predictions

---
 src/torchmetrics/functional/retrieval/auroc.py | 6 +++---
 src/torchmetrics/retrieval/auroc.py            | 9 ++++++---
 2 files changed, 9 insertions(+), 6 deletions(-)

diff --git a/src/torchmetrics/functional/retrieval/auroc.py b/src/torchmetrics/functional/retrieval/auroc.py
index 33c88643ffa..186e499159e 100644
--- a/src/torchmetrics/functional/retrieval/auroc.py
+++ b/src/torchmetrics/functional/retrieval/auroc.py
@@ -58,8 +58,8 @@ def retrieval_auroc(

     top_k_idx = preds.topk(min(top_k, preds.shape[-1]), sorted=True, dim=-1)[1]
     target = target[top_k_idx]
-    if not target.sum():
-        return tensor(0.0, device=preds.device)
+    if (0 not in target) or (1 not in target):
+        return tensor(0.0, device=preds.device, dtype=preds.dtype)

     preds = preds[top_k_idx]
-    return binary_auroc(preds, target, max_fpr=max_fpr)
+    return binary_auroc(preds, target.int(), max_fpr=max_fpr)
diff --git a/src/torchmetrics/retrieval/auroc.py b/src/torchmetrics/retrieval/auroc.py
index f0f7245fe08..ceb46e37e61 100644
--- a/src/torchmetrics/retrieval/auroc.py
+++ b/src/torchmetrics/retrieval/auroc.py
@@ -16,7 +16,7 @@
 from torch import Tensor
 from typing_extensions import Literal

-from torchmetrics.functional.classification.auroc import binary_auroc
+from torchmetrics.functional.retrieval.auroc import retrieval_auroc
 from torchmetrics.retrieval.base import RetrievalMetric
 from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE
 from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE
@@ -99,6 +99,7 @@ def __init__(
         self,
         empty_target_action: str = "neg",
         ignore_index: Optional[int] = None,
+        top_k: Optional[int] = None,
         max_fpr: Optional[float] = None,
         aggregation: Union[Literal["mean", "median", "min", "max"], Callable] = "mean",
         **kwargs: Any,
@@ -109,13 +110,15 @@ def __init__(
             aggregation=aggregation,
             **kwargs,
         )
-
+        if top_k is not None and not (isinstance(top_k, int) and top_k > 0):
+            raise ValueError("`top_k` has to be a positive integer or None")
+        self.top_k = top_k
         if max_fpr is not None and not (isinstance(max_fpr, float) and 0 < max_fpr <= 1):
             raise ValueError(f"Argument `max_fpr` should be a float in range (0, 1], but got: {max_fpr}")
         self.max_fpr = max_fpr

     def _metric(self, preds: Tensor, target: Tensor) -> Tensor:
-        return binary_auroc(preds, target, max_fpr=self.max_fpr)
+        return retrieval_auroc(preds, target, top_k=self.top_k, max_fpr=self.max_fpr)

     def plot(
         self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None

From 48a554b7960bf64c99fbbd692c9947dd6473c2d0 Mon Sep 17 00:00:00 2001
From: Nicki Skafte
Date: Fri, 1 Dec 2023 16:43:06 +0100
Subject: [PATCH 06/12] Add unit tests for RetrievalAUROC metric

---
 tests/unittests/retrieval/test_auroc.py | 209 ++++++++++++++++++++++++
 1 file changed, 209 insertions(+)
 create mode 100644 tests/unittests/retrieval/test_auroc.py

diff --git a/tests/unittests/retrieval/test_auroc.py b/tests/unittests/retrieval/test_auroc.py
new file mode 100644
index 00000000000..6c6b1951c05
--- /dev/null
+++ b/tests/unittests/retrieval/test_auroc.py
@@ -0,0 +1,209 @@
+# Copyright The Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Callable, Optional, Union
+
+import numpy as np
+import pytest
+from sklearn.metrics import roc_auc_score
+from torch import Tensor
+from torchmetrics.functional.retrieval.auroc import retrieval_auroc
+from torchmetrics.retrieval.auroc import RetrievalAUROC
+from typing_extensions import Literal
+
+from unittests.helpers import seed_all
+from unittests.retrieval.helpers import (
+    RetrievalMetricTester,
+    _concat_tests,
+    _custom_aggregate_fn,
+    _default_metric_class_input_arguments,
+    _default_metric_class_input_arguments_ignore_index,
+    _default_metric_functional_input_arguments,
+    _errors_test_class_metric_parameters_adaptive_k,
+    _errors_test_class_metric_parameters_default,
+    _errors_test_class_metric_parameters_k,
+    _errors_test_class_metric_parameters_no_pos_target,
+    _errors_test_functional_metric_parameters_adaptive_k,
+    _errors_test_functional_metric_parameters_default,
+    _errors_test_functional_metric_parameters_k,
+)
+
+seed_all(42)
+
+
+def _auroc_at_k(target: np.ndarray, preds: np.ndarray, top_k: Optional[int] = None, max_fpr: Optional[float] = None):
+    """Reference implementation using sklearn."""
+    assert target.shape == preds.shape
+    assert len(target.shape) == 1  # works only with single dimension inputs
+
+    if top_k is None or top_k > len(preds):
+        top_k = len(preds)
+    idx = np.argsort(preds)[::-1][:top_k]
+    preds = preds[idx]
+    target = target[idx]
+    if (0 not in target) or (1 not in target):
+        return 0.0
+    return roc_auc_score(target, preds, max_fpr=max_fpr)
+
+
+class TestAUROC(RetrievalMetricTester):
+    """Test class for `RetrievalAUROC` metric."""
+
+    @pytest.mark.parametrize("ddp", [True, False])
+    @pytest.mark.parametrize("empty_target_action", ["skip", "neg", "pos"])
+    @pytest.mark.parametrize("ignore_index", [None, 1])  # avoid setting 0, otherwise test with all 0 targets will fail
+    @pytest.mark.parametrize("k", [None, 1, 4, 10])
+    @pytest.mark.parametrize("max_fpr", [None, 0.25])
+    @pytest.mark.parametrize("aggregation", ["mean", "median", "max", "min", _custom_aggregate_fn])
+    @pytest.mark.parametrize(**_default_metric_class_input_arguments)
+    def test_class_metric(
+        self,
+        ddp: bool,
+        indexes: Tensor,
+        preds: Tensor,
+        target: Tensor,
+        empty_target_action: str,
+        ignore_index: int,
+        k: int,
+        max_fpr: Optional[float],
+        aggregation: Union[Literal["mean", "median", "min", "max"], Callable],
+    ):
+        """Test class implementation of metric."""
+        metric_args = {
+            "empty_target_action": empty_target_action,
+            "top_k": k,
+            "ignore_index": ignore_index,
+            "max_fpr": max_fpr,
+            "aggregation": aggregation,
+        }
+
+        self.run_class_metric_test(
+            ddp=ddp,
+            indexes=indexes,
+            preds=preds,
+            target=target,
+            metric_class=RetrievalAUROC,
+            reference_metric=_auroc_at_k,
+            metric_args=metric_args,
+        )
+
+    @pytest.mark.parametrize("ddp", [True, False])
+    @pytest.mark.parametrize("empty_target_action", ["skip", "neg", "pos"])
+    @pytest.mark.parametrize("k", [None, 1, 4, 10])
+    @pytest.mark.parametrize("max_fpr", [False, True])
+    @pytest.mark.parametrize(**_default_metric_class_input_arguments_ignore_index)
+    def test_class_metric_ignore_index(
+        self,
+        ddp: bool,
+        indexes: Tensor,
+        preds: Tensor,
+        target: Tensor,
+        empty_target_action: str,
+        k: int,
+        max_fpr: Optional[float],
+    ):
+        """Test class implementation of metric with ignore_index argument."""
+        metric_args = {
+            "empty_target_action": empty_target_action,
+            "top_k": k,
+            "ignore_index": -100,
+            "max_fpr": max_fpr,
+        }
+
+        self.run_class_metric_test(
+            ddp=ddp,
+            indexes=indexes,
+            preds=preds,
+            target=target,
+            metric_class=RetrievalAUROC,
+            reference_metric=_auroc_at_k,
+            metric_args=metric_args,
+        )
+
+    @pytest.mark.parametrize(**_default_metric_functional_input_arguments)
+    @pytest.mark.parametrize("k", [None, 1, 4, 10])
+    @pytest.mark.parametrize("max_fpr", [False, True])
+    def test_functional_metric(self, preds: Tensor, target: Tensor, k: int, max_fpr: Optional[float]):
+        """Test functional implementation of metric."""
+        self.run_functional_metric_test(
+            preds=preds,
+            target=target,
+            metric_functional=retrieval_auroc,
+            reference_metric=_auroc_at_k,
+            metric_args={},
+            top_k=k,
+            max_fpr=max_fpr,
+        )
+
+    @pytest.mark.parametrize(**_default_metric_class_input_arguments)
+    def test_precision_cpu(self, indexes: Tensor, preds: Tensor, target: Tensor):
+        """Test dtype support of the metric on CPU."""
+        self.run_precision_test_cpu(
+            indexes=indexes,
+            preds=preds,
+            target=target,
+            metric_module=RetrievalAUROC,
+            metric_functional=retrieval_auroc,
+        )
+
+    @pytest.mark.parametrize(**_default_metric_class_input_arguments)
+    def test_precision_gpu(self, indexes: Tensor, preds: Tensor, target: Tensor):
+        """Test dtype support of the metric on GPU."""
+        self.run_precision_test_gpu(
+            indexes=indexes,
+            preds=preds,
+            target=target,
+            metric_module=RetrievalAUROC,
+            metric_functional=retrieval_auroc,
+        )
+
+    @pytest.mark.parametrize(
+        **_concat_tests(
+            _errors_test_class_metric_parameters_default,
+            _errors_test_class_metric_parameters_no_pos_target,
+            _errors_test_class_metric_parameters_k,
+            _errors_test_class_metric_parameters_adaptive_k,
+        )
+    )
+    def test_arguments_class_metric(
+        self, indexes: Tensor, preds: Tensor, target: Tensor, message: str, metric_args: dict
+    ):
+        """Test that specific errors are raised for incorrect input."""
+        self.run_metric_class_arguments_test(
+            indexes=indexes,
+            preds=preds,
+            target=target,
+            metric_class=RetrievalAUROC,
+            message=message,
+            metric_args=metric_args,
+            exception_type=ValueError,
+            kwargs_update={},
+        )
+
+    @pytest.mark.parametrize(
+        **_concat_tests(
+            _errors_test_functional_metric_parameters_default,
+            _errors_test_functional_metric_parameters_k,
+            _errors_test_functional_metric_parameters_adaptive_k,
+        )
+    )
+    def test_arguments_functional_metric(self, preds: Tensor, target: Tensor, message: str, metric_args: dict):
+        """Test that specific errors are raised for incorrect input."""
+        self.run_functional_metric_arguments_test(
+            preds=preds,
+            target=target,
+            metric_functional=retrieval_auroc,
+            message=message,
+            exception_type=ValueError,
+            kwargs_update=metric_args,
+        )

From 571f2a2dc8b83f1c8fb036992657d4db56841318 Mon Sep 17 00:00:00 2001
From: Nicki Skafte
Date: Fri, 1 Dec 2023 16:45:54 +0100
Subject: [PATCH 07/12] Add Changelog

---
 CHANGELOG.md | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 468b459608b..9c7a4d46e38 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -28,7 +28,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - Added support for logging `MultiTaskWrapper` directly with lightnings `log_dict` method ([#2213](https://github.com/Lightning-AI/torchmetrics/pull/2213))

-- Added `aggregate`` argument to retrieval metrics ([#2220](https://github.com/Lightning-AI/torchmetrics/pull/2220))
+- Added `aggregate` argument to retrieval metrics ([#2220](https://github.com/Lightning-AI/torchmetrics/pull/2220))
+
+
+- Added `RetrievalAUROC` metric ([#2251](https://github.com/Lightning-AI/torchmetrics/pull/2251))
+

 ### Changed

From 1ed4128e579f56b725e460137147f690f29b9cdd Mon Sep 17 00:00:00 2001
From: Nicki Skafte
Date: Wed, 20 Dec 2023 14:45:56 +0100
Subject: [PATCH 08/12] fix spelling when skipping

---
 src/torchmetrics/retrieval/auroc.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/torchmetrics/retrieval/auroc.py b/src/torchmetrics/retrieval/auroc.py
index ceb46e37e61..3d97afd9039 100644
--- a/src/torchmetrics/retrieval/auroc.py
+++ b/src/torchmetrics/retrieval/auroc.py
@@ -22,7 +22,7 @@
 from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE

 if not _MATPLOTLIB_AVAILABLE:
-    __doctest_skip__ = ["RetrievalMAP.plot"]
+    __doctest_skip__ = ["RetrievalAUROC.plot"]


 class RetrievalAUROC(RetrievalMetric):

From 8baf448c515c232fea0538fb10a43485ad0f76ab Mon Sep 17 00:00:00 2001
From: Nicki Skafte
Date: Wed, 20 Dec 2023 15:22:02 +0100
Subject: [PATCH 09/12] fix tests

---
 src/torchmetrics/functional/retrieval/auroc.py | 4 ++--
 tests/unittests/retrieval/test_auroc.py        | 4 ----
 2 files changed, 2 insertions(+), 6 deletions(-)

diff --git a/src/torchmetrics/functional/retrieval/auroc.py b/src/torchmetrics/functional/retrieval/auroc.py
index 186e499159e..e0dc68794db 100644
--- a/src/torchmetrics/functional/retrieval/auroc.py
+++ b/src/torchmetrics/functional/retrieval/auroc.py
@@ -53,8 +53,8 @@ def retrieval_auroc(
     preds, target = _check_retrieval_functional_inputs(preds, target)

     top_k = top_k or preds.shape[-1]
-    if not isinstance(top_k, int) and top_k <= 0:
-        raise ValueError(f"Argument ``top_k`` has to be a positive integer or None, but got {top_k}.")
+    if not (isinstance(top_k, int) and top_k > 0):
+        raise ValueError("`top_k` has to be a positive integer or None")

     top_k_idx = preds.topk(min(top_k, preds.shape[-1]), sorted=True, dim=-1)[1]
     target = target[top_k_idx]
diff --git a/tests/unittests/retrieval/test_auroc.py b/tests/unittests/retrieval/test_auroc.py
index 6c6b1951c05..33eabfa2320 100644
--- a/tests/unittests/retrieval/test_auroc.py
+++ b/tests/unittests/retrieval/test_auroc.py
@@ -29,11 +29,9 @@
     _default_metric_class_input_arguments,
     _default_metric_class_input_arguments_ignore_index,
     _default_metric_functional_input_arguments,
-    _errors_test_class_metric_parameters_adaptive_k,
     _errors_test_class_metric_parameters_default,
     _errors_test_class_metric_parameters_k,
     _errors_test_class_metric_parameters_no_pos_target,
-    _errors_test_functional_metric_parameters_adaptive_k,
     _errors_test_functional_metric_parameters_default,
     _errors_test_functional_metric_parameters_k,
 )
@@ -172,7 +170,6 @@ def test_precision_gpu(self, indexes: Tensor, preds: Tensor, target: Tensor):
             _errors_test_class_metric_parameters_default,
             _errors_test_class_metric_parameters_no_pos_target,
             _errors_test_class_metric_parameters_k,
-            _errors_test_class_metric_parameters_adaptive_k,
         )
     )
     def test_arguments_class_metric(
@@ -194,7 +191,6 @@ def test_arguments_class_metric(
         **_concat_tests(
             _errors_test_functional_metric_parameters_default,
             _errors_test_functional_metric_parameters_k,
-            _errors_test_functional_metric_parameters_adaptive_k,
         )
     )
     def test_arguments_functional_metric(self, preds: Tensor, target: Tensor, message: str, metric_args: dict):

From f888b0b6aa528944310165aae92c48da8dbed784 Mon Sep 17 00:00:00 2001
From: Nicki Skafte
Date: Wed, 20 Dec 2023 15:24:12 +0100
Subject: [PATCH 10/12] Update max_fpr parameter in test_auroc.py

---
 tests/unittests/retrieval/test_auroc.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/tests/unittests/retrieval/test_auroc.py b/tests/unittests/retrieval/test_auroc.py
index 33eabfa2320..461c05bec97 100644
--- a/tests/unittests/retrieval/test_auroc.py
+++ b/tests/unittests/retrieval/test_auroc.py
@@ -98,7 +98,7 @@ def test_class_metric(
     @pytest.mark.parametrize("ddp", [True, False])
     @pytest.mark.parametrize("empty_target_action", ["skip", "neg", "pos"])
     @pytest.mark.parametrize("k", [None, 1, 4, 10])
-    @pytest.mark.parametrize("max_fpr", [False, True])
+    @pytest.mark.parametrize("max_fpr", [None, 0.25])
     @pytest.mark.parametrize(**_default_metric_class_input_arguments_ignore_index)
     def test_class_metric_ignore_index(
         self,
@@ -130,7 +130,7 @@ def test_class_metric_ignore_index(

     @pytest.mark.parametrize(**_default_metric_functional_input_arguments)
     @pytest.mark.parametrize("k", [None, 1, 4, 10])
-    @pytest.mark.parametrize("max_fpr", [False, True])
+    @pytest.mark.parametrize("max_fpr", [None, 0.25])
     def test_functional_metric(self, preds: Tensor, target: Tensor, k: int, max_fpr: Optional[float]):
         """Test functional implementation of metric."""
         self.run_functional_metric_test(

From ef7da19f38a2fd893b7b93a4eb0835a5dd267542 Mon Sep 17 00:00:00 2001
From: Jirka Borovec <6035284+Borda@users.noreply.github.com>
Date: Fri, 22 Dec 2023 21:40:28 +0100
Subject: [PATCH 11/12] Literal

Co-authored-by: Daniel Stancl <46073029+stancld@users.noreply.github.com>
---
 src/torchmetrics/retrieval/auroc.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/torchmetrics/retrieval/auroc.py b/src/torchmetrics/retrieval/auroc.py
index 3d97afd9039..dbb2d7310ec 100644
--- a/src/torchmetrics/retrieval/auroc.py
+++ b/src/torchmetrics/retrieval/auroc.py
@@ -97,7 +97,7 @@ class RetrievalAUROC(RetrievalMetric):

     def __init__(
         self,
-        empty_target_action: str = "neg",
+        empty_target_action: Literal["error", "skip", "neg", "pos"] = "neg", 
         ignore_index: Optional[int] = None,
         top_k: Optional[int] = None,
         max_fpr: Optional[float] = None,

From d95ed39f1ecccf50a30d634fb7f8b96d4e0ffa62 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Fri, 22 Dec 2023 20:41:07 +0000
Subject: [PATCH 12/12] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 src/torchmetrics/retrieval/auroc.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/torchmetrics/retrieval/auroc.py b/src/torchmetrics/retrieval/auroc.py
index dbb2d7310ec..8d5ac12a929 100644
--- a/src/torchmetrics/retrieval/auroc.py
+++ b/src/torchmetrics/retrieval/auroc.py
@@ -97,7 +97,7 @@ class RetrievalAUROC(RetrievalMetric):

     def __init__(
         self,
-        empty_target_action: Literal["error", "skip", "neg", "pos"] = "neg", 
+        empty_target_action: Literal["error", "skip", "neg", "pos"] = "neg",
        ignore_index: Optional[int] = None,
        top_k: Optional[int] = None,
        max_fpr: Optional[float] = None,
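
A minimal usage sketch of the metric this series adds, assuming a torchmetrics checkout with all twelve patches applied; the query data below is made up for illustration and the `top_k` argument is only available from patch 05 onward:

    import torch
    from torchmetrics.retrieval import RetrievalAUROC
    from torchmetrics.functional.retrieval import retrieval_auroc

    # `indexes` assigns each prediction/target pair to a query; predictions are
    # grouped per query, AUROC is computed per query, and the per-query values
    # are aggregated (mean by default, configurable via `aggregation`).
    indexes = torch.tensor([0, 0, 0, 1, 1, 1, 1])
    preds = torch.tensor([0.2, 0.3, 0.5, 0.1, 0.3, 0.5, 0.2])
    target = torch.tensor([False, False, True, False, True, False, True])

    metric = RetrievalAUROC(top_k=3)
    print(metric(preds, target, indexes=indexes))

    # The functional form scores a single query at a time.
    print(retrieval_auroc(preds[:3], target[:3]))

Note that, per the behavior introduced in patch 05, a query whose top-k targets are all negative (or all positive) contributes 0.0 rather than raising, which is what the sklearn-based `_auroc_at_k` reference in the tests mirrors.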