Replace threshold argument in binned metrics (#322)
* enhancement

* Update CHANGELOG.md

* depr

* fix

* fix tests

* docs

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Jirka Borovec <[email protected]>
Co-authored-by: Jirka <[email protected]>
Co-authored-by: Adrian Wälchli <[email protected]>
5 people authored Jun 28, 2021
1 parent 578c8f5 commit 67eb71b
Showing 5 changed files with 98 additions and 21 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -23,6 +23,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added Inception Score metric to image module ([#299](https://github.com/PyTorchLightning/metrics/pull/299))
- Added KID metric to image module ([#301](https://github.com/PyTorchLightning/metrics/pull/301))
- Added `sync` and `sync_context` methods for manually controlling when metric states are synced ([#302](https://github.com/PyTorchLightning/metrics/pull/302))
- Added `thresholds` argument to binned metrics for manually controlling the thresholds ([#322](https://github.com/PyTorchLightning/metrics/pull/322))
- Added `KLDivergence` metric ([#247](https://github.com/PyTorchLightning/metrics/pull/247))

### Changed
@@ -34,6 +35,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Deprecated

- Deprecated `torchmetrics.functional.mean_relative_error` ([#248](https://github.com/PyTorchLightning/metrics/pull/248))
- Deprecated `num_thresholds` argument in `BinnedPrecisionRecallCurve` ([#322](https://github.com/PyTorchLightning/metrics/pull/322))

### Removed

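A minimal usage sketch of the change described in the entries above (a hedged illustration, not part of the diff; the import path follows the docstring examples shown later in this commit):

from torchmetrics import BinnedPrecisionRecallCurve

# Deprecated style: still accepted at this commit, but issues a DeprecationWarning (removal planned for v0.5).
pr_curve_old = BinnedPrecisionRecallCurve(num_classes=1, num_thresholds=5)

# New style: `thresholds` accepts an integer bin count, a list of floats, or a tensor.
pr_curve_bins = BinnedPrecisionRecallCurve(num_classes=1, thresholds=5)
pr_curve_list = BinnedPrecisionRecallCurve(num_classes=1, thresholds=[0.1, 0.5, 0.9])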
34 changes: 31 additions & 3 deletions tests/classification/test_binned_precision_recall.py
@@ -28,7 +28,11 @@
from tests.classification.inputs import _input_multilabel_prob_plausible as _input_mlb_prob_ok
from tests.helpers import seed_all
from tests.helpers.testers import NUM_CLASSES, MetricTester
from torchmetrics.classification.binned_precision_recall import BinnedAveragePrecision, BinnedRecallAtFixedPrecision
from torchmetrics.classification.binned_precision_recall import (
BinnedAveragePrecision,
BinnedPrecisionRecallCurve,
BinnedRecallAtFixedPrecision,
)

seed_all(42)

@@ -112,8 +116,10 @@ class TestBinnedAveragePrecision(MetricTester):

@pytest.mark.parametrize("ddp", [True, False])
@pytest.mark.parametrize("dist_sync_on_step", [True, False])
@pytest.mark.parametrize("num_thresholds", [101, 301])
def test_binned_pr(self, preds, target, sk_metric, num_classes, ddp, dist_sync_on_step, num_thresholds):
@pytest.mark.parametrize(
"num_thresholds, thresholds", ([101, None], [301, None], [None, torch.linspace(0.0, 1.0, 101)])
)
def test_binned_pr(self, preds, target, sk_metric, num_classes, ddp, dist_sync_on_step, num_thresholds, thresholds):
# rounding will simulate binning for both implementations
preds = Tensor(np.round(preds.numpy(), 2)) + 1e-6

@@ -127,5 +133,27 @@ def test_binned_pr(self, preds, target, sk_metric, num_classes, ddp, dist_sync_o
metric_args={
"num_classes": num_classes,
"num_thresholds": num_thresholds,
"thresholds": thresholds
},
)


@pytest.mark.parametrize(
"metric_class", [BinnedAveragePrecision, BinnedRecallAtFixedPrecision, BinnedPrecisionRecallCurve]
)
def test_raises_errors_and_warning(metric_class):
if metric_class == BinnedRecallAtFixedPrecision:
metric_class = partial(metric_class, min_precision=0.5)

with pytest.warns(
DeprecationWarning,
match="Argument `num_thresholds` "
"is deprecated in v0.4 and will be removed in v0.5. Use `thresholds` instead."
):
metric_class(num_classes=10, num_thresholds=100)

with pytest.raises(
ValueError, match="Expected argument `thresholds` to either"
" be an integer, list of floats or a tensor"
):
metric_class(num_classes=10, thresholds={'temp': [10, 20]})
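A standalone sketch of the equivalence the parametrization above relies on (an assumption based on the constructor change in the next file: an integer `thresholds` is expanded to a linearly spaced grid, so `num_thresholds=101` should match `thresholds=torch.linspace(0.0, 1.0, 101)`):

import torch
from torchmetrics import BinnedPrecisionRecallCurve

metric_deprecated = BinnedPrecisionRecallCurve(num_classes=5, num_thresholds=101)  # issues DeprecationWarning
metric_new = BinnedPrecisionRecallCurve(num_classes=5, thresholds=torch.linspace(0.0, 1.0, 101))

# Both constructors should end up with the same registered threshold buffer.
assert torch.allclose(metric_deprecated.thresholds, metric_new.thresholds)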
75 changes: 61 additions & 14 deletions torchmetrics/classification/binned_precision_recall.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, List, Optional, Tuple, Union
from warnings import warn

import torch
from torch import Tensor
@@ -52,8 +53,14 @@ class BinnedPrecisionRecallCurve(Metric):
Args:
num_classes: integer with number of classes. For binary, set to 1.
num_thresholds: number of bins used for computation. More bins will lead to more detailed
curve and accurate estimates, but will be slower and consume more memory. Default 100
num_thresholds: number of bins used for computation.
.. deprecated:: v0.4
Use `thresholds`. Will be removed in v0.5.
thresholds: list or tensor with specific thresholds, or a number of bins for linearly spaced sampling.
More bins will lead to a more detailed curve and more accurate estimates,
but will be slower and consume more memory.
compute_on_step:
Forward only calls ``update()`` and return None if this is set to False. default: True
dist_sync_on_step:
@@ -62,11 +69,15 @@
process_group:
Specify the process group on which synchronization is called. default: None (which selects the entire world)
Raises:
ValueError:
If ``thresholds`` is not an int, list or tensor
Example (binary case):
>>> from torchmetrics import BinnedPrecisionRecallCurve
>>> pred = torch.tensor([0, 0.1, 0.8, 0.4])
>>> target = torch.tensor([0, 1, 1, 0])
>>> pr_curve = BinnedPrecisionRecallCurve(num_classes=1, num_thresholds=5)
>>> pr_curve = BinnedPrecisionRecallCurve(num_classes=1, thresholds=5)
>>> precision, recall, thresholds = pr_curve(pred, target)
>>> precision
tensor([0.5000, 0.5000, 1.0000, 1.0000, 1.0000, 1.0000])
@@ -81,7 +92,7 @@ class BinnedPrecisionRecallCurve(Metric):
... [0.05, 0.05, 0.75, 0.05, 0.05],
... [0.05, 0.05, 0.05, 0.75, 0.05]])
>>> target = torch.tensor([0, 1, 3, 2])
>>> pr_curve = BinnedPrecisionRecallCurve(num_classes=5, num_thresholds=3)
>>> pr_curve = BinnedPrecisionRecallCurve(num_classes=5, thresholds=3)
>>> precision, recall, thresholds = pr_curve(pred, target)
>>> precision # doctest: +NORMALIZE_WHITESPACE
[tensor([0.2500, 1.0000, 1.0000, 1.0000]),
@@ -106,10 +117,11 @@
def __init__(
self,
num_classes: int,
num_thresholds: int = 100,
thresholds: Optional[Union[Tensor, List[float]]] = None,
compute_on_step: bool = True,
dist_sync_on_step: bool = False,
process_group: Optional[Any] = None,
num_thresholds: Optional[int] = 100, # ToDo: remove in v0.5
):
super().__init__(
compute_on_step=compute_on_step,
@@ -118,14 +130,27 @@ def __init__(
)

self.num_classes = num_classes
self.num_thresholds = num_thresholds
thresholds = torch.linspace(0, 1.0, num_thresholds)
self.register_buffer("thresholds", thresholds)
if thresholds is None and num_thresholds is not None:
warn(
"Argument `num_thresholds` is deprecated in v0.4 and will be removed in v0.5."
" Use `thresholds` instead.", DeprecationWarning
)
thresholds = num_thresholds
if isinstance(thresholds, int):
self.num_thresholds = thresholds
thresholds = torch.linspace(0, 1.0, thresholds)
self.register_buffer("thresholds", thresholds)
elif thresholds is not None:
if not isinstance(thresholds, (list, Tensor)):
raise ValueError('Expected argument `thresholds` to either be an integer, list of floats or a tensor')
thresholds = torch.tensor(thresholds) if isinstance(thresholds, list) else thresholds
self.num_thresholds = thresholds.numel()
self.register_buffer("thresholds", thresholds)

for name in ("TPs", "FPs", "FNs"):
self.add_state(
name=name,
default=torch.zeros(num_classes, num_thresholds, dtype=torch.float32),
default=torch.zeros(num_classes, self.num_thresholds, dtype=torch.float32),
dist_reduce_fx="sum",
)

@@ -185,13 +210,23 @@ class BinnedAveragePrecision(BinnedPrecisionRecallCurve):
Args:
num_classes: integer with number of classes. Not necessary to provide
for binary problems.
num_thresholds: number of bins used for computation. More bins will lead to more detailed
curve and accurate estimates, but will be slower and consume more memory. Default 100
num_thresholds: number of bins used for computation.
.. deprecated:: v0.4
Use `thresholds`. Will be removed in v0.5.
thresholds: list or tensor with specific thresholds, or a number of bins for linearly spaced sampling.
More bins will lead to a more detailed curve and more accurate estimates,
but will be slower and consume more memory.
compute_on_step:
Forward only calls ``update()`` and return None if this is set to False. default: True
process_group:
Specify the process group on which synchronization is called. default: None (which selects the entire world)
Raises:
ValueError:
If ``thresholds`` is not an int, list or tensor
Example (binary case):
>>> from torchmetrics import BinnedAveragePrecision
>>> pred = torch.tensor([0, 1, 2, 3])
@@ -233,13 +268,23 @@ class BinnedRecallAtFixedPrecision(BinnedPrecisionRecallCurve):
Args:
num_classes: integer with number of classes. Provide 1 for binary problems.
min_precision: float value specifying minimum precision threshold.
num_thresholds: number of bins used for computation. More bins will lead to more detailed
curve and accurate estimates, but will be slower and consume more memory. Default 100
num_thresholds: number of bins used for computation.
.. deprecated:: v0.4
Use `thresholds`. Will be removed in v0.5.
thresholds: list or tensor with specific thresholds, or a number of bins for linearly spaced sampling.
More bins will lead to a more detailed curve and more accurate estimates,
but will be slower and consume more memory.
compute_on_step:
Forward only calls ``update()`` and return None if this is set to False. default: True
process_group:
Specify the process group on which synchronization is called. default: None (which selects the entire world)
Raises:
ValueError:
If ``thresholds`` is not an int, list or tensor
Example (binary case):
>>> from torchmetrics import BinnedRecallAtFixedPrecision
>>> pred = torch.tensor([0, 0.2, 0.5, 0.8])
@@ -264,14 +309,16 @@ def __init__(
self,
num_classes: int,
min_precision: float,
num_thresholds: int = 100,
thresholds: Optional[Union[Tensor, List[float]]] = None,
compute_on_step: bool = True,
dist_sync_on_step: bool = False,
process_group: Optional[Any] = None,
num_thresholds: int = 100, # ToDo: remove in v0.5
):
super().__init__(
num_classes=num_classes,
num_thresholds=num_thresholds,
thresholds=thresholds,
compute_on_step=compute_on_step,
dist_sync_on_step=dist_sync_on_step,
process_group=process_group,
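A brief sketch (not part of the diff) of the constructor behaviour added above: `thresholds` is normalised to a tensor and registered as a buffer, and `num_thresholds` is derived from it, so the grid travels with the rest of the metric state:

from torchmetrics import BinnedPrecisionRecallCurve

metric = BinnedPrecisionRecallCurve(num_classes=3, thresholds=[0.25, 0.5, 0.75])
print(metric.thresholds)      # tensor([0.2500, 0.5000, 0.7500]) -- the registered buffer
print(metric.num_thresholds)  # 3, derived from thresholds.numel()
# metric.to("cuda")           # would move the thresholds buffer together with the TPs/FPs/FNs state (if CUDA is available)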
4 changes: 2 additions & 2 deletions torchmetrics/regression/psnr.py
@@ -19,8 +19,8 @@

class PSNR(_PSNR):
"""
.. deprecated:: 0.4
The PSNR was moved to `torchmetrics.image.psnr`.
.. deprecated:: v0.4
The PSNR was moved to `torchmetrics.image.psnr`. It will be removed in v0.5.
"""

4 changes: 2 additions & 2 deletions torchmetrics/regression/ssim.py
@@ -19,8 +19,8 @@

class SSIM(_SSIM):
"""
.. deprecated:: 0.4
The SSIM was moved to `torchmetrics.image.ssim`.
.. deprecated:: v0.4
The SSIM was moved to `torchmetrics.image.ssim`. It will be removed in v0.5.
"""

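A short import sketch of the relocation these deprecation notes describe (new module paths taken from the docstrings above; their availability in v0.4 is assumed here):

# New locations:
from torchmetrics.image.psnr import PSNR
from torchmetrics.image.ssim import SSIM

# Deprecated locations (the modules changed in this commit) still work but are slated for removal in v0.5:
# from torchmetrics.regression.psnr import PSNR
# from torchmetrics.regression.ssim import SSIM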
