From 67eb71b095d3bd002f8c4df2983070075cf9b7b5 Mon Sep 17 00:00:00 2001
From: Nicki Skafte
Date: Tue, 29 Jun 2021 00:10:49 +0200
Subject: [PATCH] Replace threshold argument in binned metrics (#322)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* enhancement

* Update CHANGELOG.md

* depr

* fix

* fix tests

* docs

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Jirka Borovec
Co-authored-by: Jirka
Co-authored-by: Adrian Wälchli
---
 CHANGELOG.md                                       |  2 +
 .../test_binned_precision_recall.py                | 34 ++++++++-
 .../classification/binned_precision_recall.py      | 75 +++++++++++++++----
 torchmetrics/regression/psnr.py                    |  4 +-
 torchmetrics/regression/ssim.py                    |  4 +-
 5 files changed, 98 insertions(+), 21 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 2add3513a9f..39b8ca38b50 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -23,6 +23,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - Added Inception Score metric to image module ([#299](https://github.com/PyTorchLightning/metrics/pull/299))
 - Added KID metric to image module ([#301](https://github.com/PyTorchLightning/metrics/pull/301))
 - Added `sync` and `sync_context` methods for manually controlling when metric states are synced ([#302](https://github.com/PyTorchLightning/metrics/pull/302))
+- Added `thresholds` argument to binned metrics for manually controlling the thresholds ([#322](https://github.com/PyTorchLightning/metrics/pull/322))
 - Added `KLDivergence` metric ([#247](https://github.com/PyTorchLightning/metrics/pull/247))
 
 ### Changed
@@ -34,6 +35,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 ### Deprecated
 
 - Deprecated `torchmetrics.functional.mean_relative_error` ([#248](https://github.com/PyTorchLightning/metrics/pull/248))
+- Deprecated `num_thresholds` argument in `BinnedPrecisionRecallCurve` ([#322](https://github.com/PyTorchLightning/metrics/pull/322))
 
 ### Removed
 
diff --git a/tests/classification/test_binned_precision_recall.py b/tests/classification/test_binned_precision_recall.py
index 3f3b76922ec..980de2cb941 100644
--- a/tests/classification/test_binned_precision_recall.py
+++ b/tests/classification/test_binned_precision_recall.py
@@ -28,7 +28,11 @@
 from tests.classification.inputs import _input_multilabel_prob_plausible as _input_mlb_prob_ok
 from tests.helpers import seed_all
 from tests.helpers.testers import NUM_CLASSES, MetricTester
-from torchmetrics.classification.binned_precision_recall import BinnedAveragePrecision, BinnedRecallAtFixedPrecision
+from torchmetrics.classification.binned_precision_recall import (
+    BinnedAveragePrecision,
+    BinnedPrecisionRecallCurve,
+    BinnedRecallAtFixedPrecision,
+)
 
 seed_all(42)
 
@@ -112,8 +116,10 @@ class TestBinnedAveragePrecision(MetricTester):
 
     @pytest.mark.parametrize("ddp", [True, False])
    @pytest.mark.parametrize("dist_sync_on_step", [True, False])
-    @pytest.mark.parametrize("num_thresholds", [101, 301])
-    def test_binned_pr(self, preds, target, sk_metric, num_classes, ddp, dist_sync_on_step, num_thresholds):
+    @pytest.mark.parametrize(
+        "num_thresholds, thresholds", ([101, None], [301, None], [None, torch.linspace(0.0, 1.0, 101)])
+    )
+    def test_binned_pr(self, preds, target, sk_metric, num_classes, ddp, dist_sync_on_step, num_thresholds, thresholds):
         # rounding will simulate binning for both implementations
         preds = Tensor(np.round(preds.numpy(), 2)) + 1e-6
 
@@ -127,5 +133,27 @@ def test_binned_pr(self, preds, target, sk_metric, num_classes, ddp, dist_sync_o
             metric_args={
                 "num_classes": num_classes,
                 "num_thresholds": num_thresholds,
+                "thresholds": thresholds
             },
         )
+
+
+@pytest.mark.parametrize(
+    "metric_class", [BinnedAveragePrecision, BinnedRecallAtFixedPrecision, BinnedPrecisionRecallCurve]
+)
+def test_raises_errors_and_warning(metric_class):
+    if metric_class == BinnedRecallAtFixedPrecision:
+        metric_class = partial(metric_class, min_precision=0.5)
+
+    with pytest.warns(
+        DeprecationWarning,
+        match="Argument `num_thresholds` "
+        "is deprecated in v0.4 and will be removed in v0.5. Use `thresholds` instead."
+    ):
+        metric_class(num_classes=10, num_thresholds=100)
+
+    with pytest.raises(
+        ValueError, match="Expected argument `thresholds` to either"
+        " be an integer, list of floats or a tensor"
+    ):
+        metric_class(num_classes=10, thresholds={'temp': [10, 20]})
diff --git a/torchmetrics/classification/binned_precision_recall.py b/torchmetrics/classification/binned_precision_recall.py
index 7031915ab3c..19081c6baa4 100644
--- a/torchmetrics/classification/binned_precision_recall.py
+++ b/torchmetrics/classification/binned_precision_recall.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from typing import Any, List, Optional, Tuple, Union
+from warnings import warn
 
 import torch
 from torch import Tensor
@@ -52,8 +53,14 @@ class BinnedPrecisionRecallCurve(Metric):
 
     Args:
         num_classes: integer with number of classes. For binary, set to 1.
-        num_thresholds: number of bins used for computation. More bins will lead to more detailed
-            curve and accurate estimates, but will be slower and consume more memory. Default 100
+        num_thresholds: number of bins used for computation.
+
+            .. deprecated:: v0.4
+                Use `thresholds`. Will be removed in v0.5.
+
+        thresholds: list or tensor with specific thresholds or a number of bins from linear sampling.
+            More bins will lead to a more detailed curve and more accurate estimates,
+            but will be slower and consume more memory.
         compute_on_step:
             Forward only calls ``update()`` and return None if this is set to False. default: True
         dist_sync_on_step:
@@ -62,11 +69,15 @@ class BinnedPrecisionRecallCurve(Metric):
         process_group:
             Specify the process group on which synchronization is called. default: None (which selects the entire world)
 
+    Raises:
+        ValueError:
+            If ``thresholds`` is not an int, list or tensor
+
     Example (binary case):
         >>> from torchmetrics import BinnedPrecisionRecallCurve
         >>> pred = torch.tensor([0, 0.1, 0.8, 0.4])
         >>> target = torch.tensor([0, 1, 1, 0])
-        >>> pr_curve = BinnedPrecisionRecallCurve(num_classes=1, num_thresholds=5)
+        >>> pr_curve = BinnedPrecisionRecallCurve(num_classes=1, thresholds=5)
         >>> precision, recall, thresholds = pr_curve(pred, target)
         >>> precision
         tensor([0.5000, 0.5000, 1.0000, 1.0000, 1.0000, 1.0000])
@@ -81,7 +92,7 @@ class BinnedPrecisionRecallCurve(Metric):
         ...                      [0.05, 0.05, 0.75, 0.05, 0.05],
         ...                      [0.05, 0.05, 0.05, 0.75, 0.05]])
         >>> target = torch.tensor([0, 1, 3, 2])
-        >>> pr_curve = BinnedPrecisionRecallCurve(num_classes=5, num_thresholds=3)
+        >>> pr_curve = BinnedPrecisionRecallCurve(num_classes=5, thresholds=3)
         >>> precision, recall, thresholds = pr_curve(pred, target)
         >>> precision # doctest: +NORMALIZE_WHITESPACE
         [tensor([0.2500, 1.0000, 1.0000, 1.0000]),
@@ -106,10 +117,11 @@ class BinnedPrecisionRecallCurve(Metric):
     def __init__(
         self,
         num_classes: int,
-        num_thresholds: int = 100,
+        thresholds: Optional[Union[Tensor, List[float]]] = None,
         compute_on_step: bool = True,
         dist_sync_on_step: bool = False,
         process_group: Optional[Any] = None,
+        num_thresholds: Optional[int] = 100,  # ToDo: remove in v0.5
     ):
         super().__init__(
             compute_on_step=compute_on_step,
@@ -118,14 +130,27 @@ def __init__(
         )
 
         self.num_classes = num_classes
-        self.num_thresholds = num_thresholds
-        thresholds = torch.linspace(0, 1.0, num_thresholds)
-        self.register_buffer("thresholds", thresholds)
+        if thresholds is None and num_thresholds is not None:
+            warn(
+                "Argument `num_thresholds` is deprecated in v0.4 and will be removed in v0.5."
+                " Use `thresholds` instead.", DeprecationWarning
+            )
+            thresholds = num_thresholds
+        if isinstance(thresholds, int):
+            self.num_thresholds = thresholds
+            thresholds = torch.linspace(0, 1.0, thresholds)
+            self.register_buffer("thresholds", thresholds)
+        elif thresholds is not None:
+            if not isinstance(thresholds, (list, Tensor)):
+                raise ValueError('Expected argument `thresholds` to either be an integer, list of floats or a tensor')
+            thresholds = torch.tensor(thresholds) if isinstance(thresholds, list) else thresholds
+            self.num_thresholds = thresholds.numel()
+            self.register_buffer("thresholds", thresholds)
 
         for name in ("TPs", "FPs", "FNs"):
             self.add_state(
                 name=name,
-                default=torch.zeros(num_classes, num_thresholds, dtype=torch.float32),
+                default=torch.zeros(num_classes, self.num_thresholds, dtype=torch.float32),
                 dist_reduce_fx="sum",
             )
@@ -185,13 +210,23 @@ class BinnedAveragePrecision(BinnedPrecisionRecallCurve):
 
     Args:
         num_classes: integer with number of classes. Not nessesary to provide for binary problems.
-        num_thresholds: number of bins used for computation. More bins will lead to more detailed
-            curve and accurate estimates, but will be slower and consume more memory. Default 100
+        num_thresholds: number of bins used for computation.
+
+            .. deprecated:: v0.4
+                Use `thresholds`. Will be removed in v0.5.
+
+        thresholds: list or tensor with specific thresholds or a number of bins from linear sampling.
+            More bins will lead to a more detailed curve and more accurate estimates,
+            but will be slower and consume more memory.
         compute_on_step:
             Forward only calls ``update()`` and return None if this is set to False. default: True
         process_group:
             Specify the process group on which synchronization is called. default: None (which selects the entire world)
 
+    Raises:
+        ValueError:
+            If ``thresholds`` is not a list or tensor
+
     Example (binary case):
         >>> from torchmetrics import BinnedAveragePrecision
         >>> pred = torch.tensor([0, 1, 2, 3])
@@ -233,13 +268,23 @@ class BinnedRecallAtFixedPrecision(BinnedPrecisionRecallCurve):
     Args:
         num_classes: integer with number of classes. Provide 1 for for binary problems.
         min_precision: float value specifying minimum precision threshold.
-        num_thresholds: number of bins used for computation. More bins will lead to more detailed
-            curve and accurate estimates, but will be slower and consume more memory. Default 100
+        num_thresholds: number of bins used for computation.
+
+            .. deprecated:: v0.4
+                Use `thresholds`. Will be removed in v0.5.
+
+        thresholds: list or tensor with specific thresholds or a number of bins from linear sampling.
+            More bins will lead to a more detailed curve and more accurate estimates,
+            but will be slower and consume more memory.
         compute_on_step:
             Forward only calls ``update()`` and return None if this is set to False. default: True
         process_group:
             Specify the process group on which synchronization is called. default: None (which selects the entire world)
 
+    Raises:
+        ValueError:
+            If ``thresholds`` is not a list or tensor
+
     Example (binary case):
         >>> from torchmetrics import BinnedRecallAtFixedPrecision
         >>> pred = torch.tensor([0, 0.2, 0.5, 0.8])
@@ -264,14 +309,16 @@ def __init__(
         self,
         num_classes: int,
         min_precision: float,
-        num_thresholds: int = 100,
+        thresholds: Optional[Union[Tensor, List[float]]] = None,
         compute_on_step: bool = True,
         dist_sync_on_step: bool = False,
         process_group: Optional[Any] = None,
+        num_thresholds: int = 100,  # ToDo: remove in v0.5
     ):
         super().__init__(
             num_classes=num_classes,
             num_thresholds=num_thresholds,
+            thresholds=thresholds,
             compute_on_step=compute_on_step,
             dist_sync_on_step=dist_sync_on_step,
             process_group=process_group,
diff --git a/torchmetrics/regression/psnr.py b/torchmetrics/regression/psnr.py
index 8e602eba7f2..e80a6ccd858 100644
--- a/torchmetrics/regression/psnr.py
+++ b/torchmetrics/regression/psnr.py
@@ -19,8 +19,8 @@
 
 class PSNR(_PSNR):
     """
-    .. deprecated:: 0.4
-        The PSNR was moved to `torchmetrics.image.psnr`.
+    .. deprecated:: v0.4
+        The PSNR was moved to `torchmetrics.image.psnr`. It will be removed in v0.5.
 
     """
 
diff --git a/torchmetrics/regression/ssim.py b/torchmetrics/regression/ssim.py
index 27f10d207ae..c8507cf8e3c 100644
--- a/torchmetrics/regression/ssim.py
+++ b/torchmetrics/regression/ssim.py
@@ -19,8 +19,8 @@
 
 class SSIM(_SSIM):
     """
-    .. deprecated:: 0.4
-        The SSIM was moved to `torchmetrics.image.ssim`.
+    .. deprecated:: v0.4
+        The SSIM was moved to `torchmetrics.image.ssim`. It will be removed in v0.5.
 
     """
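
For reference, a minimal usage sketch of the `thresholds` argument introduced by this patch. The binary-case inputs and the integer form are taken from the updated docstring example above; the list-valued call and its specific values are illustrative of the accepted types (int, list of floats, or tensor), not part of the commit.

    import torch
    from torchmetrics import BinnedPrecisionRecallCurve

    pred = torch.tensor([0, 0.1, 0.8, 0.4])
    target = torch.tensor([0, 1, 1, 0])

    # Passing an integer keeps the previous behaviour: that many linearly spaced bins in [0, 1].
    pr_curve = BinnedPrecisionRecallCurve(num_classes=1, thresholds=5)
    precision, recall, thresholds = pr_curve(pred, target)

    # Passing a list of floats (or a tensor) evaluates the curve at exactly those thresholds.
    pr_curve = BinnedPrecisionRecallCurve(num_classes=1, thresholds=[0.25, 0.5, 0.75])
    precision, recall, thresholds = pr_curve(pred, target)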