diff --git a/torchmetrics/classification/f_beta.py b/torchmetrics/classification/f_beta.py index 3e41b6146ad..a79102a45d6 100644 --- a/torchmetrics/classification/f_beta.py +++ b/torchmetrics/classification/f_beta.py @@ -12,13 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. from typing import Any, Callable, Optional -from warnings import warn import torch from torch import Tensor from torchmetrics.classification.stat_scores import StatScores from torchmetrics.functional.classification.f_beta import _fbeta_compute +from torchmetrics.utilities import _deprecation_warn_arg_multilabel class FBeta(StatScores): @@ -114,6 +114,9 @@ class FBeta(StatScores): dist_sync_fn: Callback that performs the allgather operation on the metric state. When ``None``, DDP will be used to perform the allgather. + multilabel: + .. deprecated:: 0.3 + Argument will not have any effect and will be removed in v0.4, please use ``multiclass`` intead. Raises: ValueError: @@ -143,14 +146,9 @@ def __init__( dist_sync_on_step: bool = False, process_group: Optional[Any] = None, dist_sync_fn: Callable = None, - is_multiclass: Optional[bool] = None, # todo: deprecated, remove in v0.4 + multilabel: Optional[bool] = None, # todo: deprecated, remove in v0.4 ): - if is_multiclass is not None and multiclass is None: - warn( - "Argument `is_multiclass` was deprecated in v0.3.0 and will be removed in v0.4. Use `multiclass`.", - DeprecationWarning - ) - multiclass = is_multiclass + _deprecation_warn_arg_multilabel(multilabel) self.beta = beta allowed_average = ["micro", "macro", "weighted", "samples", "none", None] @@ -269,6 +267,10 @@ class F1(FBeta): dist_sync_fn: Callback that performs the allgather operation on the metric state. When ``None``, DDP will be used to perform the allgather. + multilabel: + .. deprecated:: 0.3 + Argument will not have any effect and will be removed in v0.4, please use ``multiclass`` intead. + Example: >>> from torchmetrics import F1 @@ -292,14 +294,9 @@ def __init__( dist_sync_on_step: bool = False, process_group: Optional[Any] = None, dist_sync_fn: Callable = None, - is_multiclass: Optional[bool] = None, # todo: deprecated, remove in v0.4 + multilabel: Optional[bool] = None, # todo: deprecated, remove in v0.4 ): - if is_multiclass is not None and multiclass is None: - warn( - "Argument `is_multiclass` was deprecated in v0.3.0 and will be removed in v0.4. Use `multiclass`.", - DeprecationWarning - ) - multiclass = is_multiclass + _deprecation_warn_arg_multilabel(multilabel) super().__init__( num_classes=num_classes, diff --git a/torchmetrics/classification/precision_recall.py b/torchmetrics/classification/precision_recall.py index de137262878..8e0b51dd902 100644 --- a/torchmetrics/classification/precision_recall.py +++ b/torchmetrics/classification/precision_recall.py @@ -12,13 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. from typing import Any, Callable, Optional -from warnings import warn import torch from torch import Tensor from torchmetrics.classification.stat_scores import StatScores from torchmetrics.functional.classification.precision_recall import _precision_compute, _recall_compute +from torchmetrics.utilities import _deprecation_warn_arg_multilabel class Precision(StatScores): @@ -104,6 +104,9 @@ class Precision(StatScores): dist_sync_fn: Callback that performs the allgather operation on the metric state. When ``None``, DDP will be used to perform the allgather. + multilabel: + .. deprecated:: 0.3 + Argument will not have any effect and will be removed in v0.4, please use ``multiclass`` intead. Raises: ValueError: @@ -135,14 +138,9 @@ def __init__( dist_sync_on_step: bool = False, process_group: Optional[Any] = None, dist_sync_fn: Callable = None, - is_multiclass: Optional[bool] = None, # todo: deprecated, remove in v0.4 + multilabel: Optional[bool] = None, # todo: deprecated, remove in v0.4 ): - if is_multiclass is not None and multiclass is None: - warn( - "Argument `is_multiclass` was deprecated in v0.3.0 and will be removed in v0.4. Use `multiclass`.", - DeprecationWarning - ) - multiclass = is_multiclass + _deprecation_warn_arg_multilabel(multilabel) allowed_average = ["micro", "macro", "weighted", "samples", "none", None] if average not in allowed_average: @@ -263,6 +261,9 @@ class Recall(StatScores): dist_sync_fn: Callback that performs the allgather operation on the metric state. When ``None``, DDP will be used to perform the allgather. + multilabel: + .. deprecated:: 0.3 + Argument will not have any effect and will be removed in v0.4, please use ``multiclass`` intead. Raises: ValueError: @@ -294,14 +295,9 @@ def __init__( dist_sync_on_step: bool = False, process_group: Optional[Any] = None, dist_sync_fn: Callable = None, - is_multiclass: Optional[bool] = None, # todo: deprecated, remove in v0.4 + multilabel: Optional[bool] = None, # todo: deprecated, remove in v0.4 ): - if is_multiclass is not None and multiclass is None: - warn( - "Argument `is_multiclass` was deprecated in v0.3.0 and will be removed in v0.4. Use `multiclass`.", - DeprecationWarning - ) - multiclass = is_multiclass + _deprecation_warn_arg_multilabel(multilabel) allowed_average = ["micro", "macro", "weighted", "samples", "none", None] if average not in allowed_average: diff --git a/torchmetrics/classification/stat_scores.py b/torchmetrics/classification/stat_scores.py index 9e896401bb1..a9c8fb7fb17 100644 --- a/torchmetrics/classification/stat_scores.py +++ b/torchmetrics/classification/stat_scores.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. from typing import Any, Callable, Optional, Tuple -from warnings import warn import numpy as np import torch @@ -146,14 +145,7 @@ def __init__( dist_sync_on_step: bool = False, process_group: Optional[Any] = None, dist_sync_fn: Callable = None, - is_multiclass: Optional[bool] = None, # todo: deprecated, remove in v0.4 ): - if is_multiclass is not None and multiclass is None: - warn( - "Argument `is_multiclass` was deprecated in v0.3.0 and will be removed in v0.4. Use `multiclass`.", - DeprecationWarning - ) - multiclass = is_multiclass super().__init__( compute_on_step=compute_on_step, diff --git a/torchmetrics/functional/classification/f_beta.py b/torchmetrics/functional/classification/f_beta.py index 6e0ab627a37..c37806d0cdc 100644 --- a/torchmetrics/functional/classification/f_beta.py +++ b/torchmetrics/functional/classification/f_beta.py @@ -12,13 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. from typing import Optional -from warnings import warn import torch from torch import Tensor from torchmetrics.classification.stat_scores import _reduce_stat_scores from torchmetrics.functional.classification.stat_scores import _stat_scores_update +from torchmetrics.utilities import _deprecation_warn_arg_multilabel from torchmetrics.utilities.enums import AverageMethod, MDMCAverageMethod @@ -82,7 +82,7 @@ def fbeta( threshold: float = 0.5, top_k: Optional[int] = None, multiclass: Optional[bool] = None, - is_multiclass: Optional[bool] = None, # todo: deprecated, remove in v0.4 + multilabel: Optional[bool] = None, # todo: deprecated, remove in v0.4 ) -> Tensor: r""" Computes f_beta metric. @@ -158,6 +158,9 @@ def fbeta( than what they appear to be. See the parameter's :ref:`documentation section ` for a more detailed explanation and examples. + multilabel: + .. deprecated:: 0.3 + Argument will not have any effect and will be removed in v0.4, please use ``multiclass`` intead. Return: The shape of the returned tensor depends on the ``average`` parameter @@ -174,12 +177,7 @@ def fbeta( tensor(0.3333) """ - if is_multiclass is not None and multiclass is None: - warn( - "Argument `is_multiclass` was deprecated in v0.3.0 and will be removed in v0.4. Use `multiclass`.", - DeprecationWarning - ) - multiclass = is_multiclass + _deprecation_warn_arg_multilabel(multilabel) allowed_average = ["micro", "macro", "weighted", "samples", "none", None] if average not in allowed_average: @@ -222,7 +220,7 @@ def f1( threshold: float = 0.5, top_k: Optional[int] = None, multiclass: Optional[bool] = None, - is_multiclass: Optional[bool] = None, # todo: deprecated, remove in v0.4 + multilabel: Optional[bool] = None, # todo: deprecated, remove in v0.4 ) -> Tensor: """ Computes F1 metric. F1 metrics correspond to a equally weighted average of the @@ -301,6 +299,9 @@ def f1( than what they appear to be. See the parameter's :ref:`documentation section ` for a more detailed explanation and examples. + multilabel: + .. deprecated:: 0.3 + Argument will not have any effect and will be removed in v0.4, please use ``multiclass`` intead. Return: The shape of the returned tensor depends on the ``average`` parameter @@ -316,10 +317,5 @@ def f1( >>> f1(preds, target, num_classes=3) tensor(0.3333) """ - if is_multiclass is not None and multiclass is None: - warn( - "Argument `is_multiclass` was deprecated in v0.3.0 and will be removed in v0.4. Use `multiclass`.", - DeprecationWarning - ) - multiclass = is_multiclass + _deprecation_warn_arg_multilabel(multilabel) return fbeta(preds, target, 1.0, average, mdmc_average, ignore_index, num_classes, threshold, top_k, multiclass) diff --git a/torchmetrics/functional/classification/precision_recall.py b/torchmetrics/functional/classification/precision_recall.py index b4d530c1e29..331792ba74f 100644 --- a/torchmetrics/functional/classification/precision_recall.py +++ b/torchmetrics/functional/classification/precision_recall.py @@ -12,13 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. from typing import Optional, Tuple -from warnings import warn import torch from torch import Tensor from torchmetrics.classification.stat_scores import _reduce_stat_scores from torchmetrics.functional.classification.stat_scores import _stat_scores_update +from torchmetrics.utilities import _deprecation_warn_arg_multilabel def _precision_compute( @@ -49,7 +49,7 @@ def precision( threshold: float = 0.5, top_k: Optional[int] = None, multiclass: Optional[bool] = None, - is_multiclass: Optional[bool] = None, # todo: deprecated, remove in v0.4 + multilabel: Optional[bool] = None, # todo: deprecated, remove in v0.4 ) -> Tensor: r""" Computes `Precision `_: @@ -124,6 +124,9 @@ def precision( than what they appear to be. See the parameter's :ref:`documentation section ` for a more detailed explanation and examples. + multilabel: + .. deprecated:: 0.3 + Argument will not have any effect and will be removed in v0.4, please use ``multiclass`` intead. Return: The shape of the returned tensor depends on the ``average`` parameter @@ -154,12 +157,7 @@ def precision( tensor(0.2500) """ - if is_multiclass is not None and multiclass is None: - warn( - "Argument `is_multiclass` was deprecated in v0.3.0 and will be removed in v0.4. Use `multiclass`.", - DeprecationWarning - ) - multiclass = is_multiclass + _deprecation_warn_arg_multilabel(multilabel) allowed_average = ["micro", "macro", "weighted", "samples", "none", None] if average not in allowed_average: @@ -220,7 +218,7 @@ def recall( threshold: float = 0.5, top_k: Optional[int] = None, multiclass: Optional[bool] = None, - is_multiclass: Optional[bool] = None, # todo: deprecated, remove in v0.4 + multilabel: Optional[bool] = None, # todo: deprecated, remove in v0.4 ) -> Tensor: r""" Computes `Recall `_: @@ -295,6 +293,9 @@ def recall( than what they appear to be. See the parameter's :ref:`documentation section ` for a more detailed explanation and examples. + multilabel: + .. deprecated:: 0.3 + Argument will not have any effect and will be removed in v0.4, please use ``multiclass`` intead. Return: The shape of the returned tensor depends on the ``average`` parameter @@ -325,12 +326,7 @@ def recall( tensor(0.2500) """ - if is_multiclass is not None and multiclass is None: - warn( - "Argument `is_multiclass` was deprecated in v0.3.0 and will be removed in v0.4. Use `multiclass`.", - DeprecationWarning - ) - multiclass = is_multiclass + _deprecation_warn_arg_multilabel(multilabel) allowed_average = ["micro", "macro", "weighted", "samples", "none", None] if average not in allowed_average: @@ -372,7 +368,7 @@ def precision_recall( threshold: float = 0.5, top_k: Optional[int] = None, multiclass: Optional[bool] = None, - is_multiclass: Optional[bool] = None, # todo: deprecated, remove in v0.4 + multilabel: Optional[bool] = None, # todo: deprecated, remove in v0.4 ) -> Tuple[Tensor, Tensor]: r""" Computes `Precision and Recall `_: @@ -450,6 +446,9 @@ def precision_recall( than what they appear to be. See the parameter's :ref:`documentation section ` for a more detailed explanation and examples. + multilabel: + .. deprecated:: 0.3 + Argument will not have any effect and will be removed in v0.4, please use ``multiclass`` intead. Return: The function returns a tuple with two elements: precision and recall. Their shape @@ -481,12 +480,7 @@ def precision_recall( (tensor(0.2500), tensor(0.2500)) """ - if is_multiclass is not None and multiclass is None: - warn( - "Argument `is_multiclass` was deprecated in v0.3.0 and will be removed in v0.4. Use `multiclass`.", - DeprecationWarning - ) - multiclass = is_multiclass + _deprecation_warn_arg_multilabel(multilabel) allowed_average = ["micro", "macro", "weighted", "samples", "none", None] if average not in allowed_average: diff --git a/torchmetrics/functional/classification/stat_scores.py b/torchmetrics/functional/classification/stat_scores.py index 9054c3b6433..e853370d5ad 100644 --- a/torchmetrics/functional/classification/stat_scores.py +++ b/torchmetrics/functional/classification/stat_scores.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. from typing import Optional, Tuple -from warnings import warn import torch from torch import Tensor, tensor @@ -85,14 +84,7 @@ def _stat_scores_update( threshold: float = 0.5, multiclass: Optional[bool] = None, ignore_index: Optional[int] = None, - is_multiclass: Optional[bool] = None, # todo: deprecated, remove in v0.4 ) -> Tuple[Tensor, Tensor, Tensor, Tensor]: - if is_multiclass is not None and multiclass is None: - warn( - "Argument `is_multiclass` was deprecated in v0.3.0 and will be removed in v0.4. Use `multiclass`.", - DeprecationWarning - ) - multiclass = is_multiclass preds, target, _ = _input_format_classification( preds, target, threshold=threshold, num_classes=num_classes, multiclass=multiclass, top_k=top_k @@ -155,7 +147,6 @@ def stat_scores( threshold: float = 0.5, multiclass: Optional[bool] = None, ignore_index: Optional[int] = None, - is_multiclass: Optional[bool] = None, # todo: deprecated, remove in v0.4 ) -> Tensor: """Computes the number of true positives, false positives, true negatives, false negatives. Related to `Type I and Type II errors `__ @@ -280,12 +271,6 @@ def stat_scores( >>> stat_scores(preds, target, reduce='micro') tensor([2, 2, 6, 2, 4]) """ - if is_multiclass is not None and multiclass is None: - warn( - "Argument `is_multiclass` was deprecated in v0.3.0 and will be removed in v0.4. Use `multiclass`.", - DeprecationWarning - ) - multiclass = is_multiclass if reduce not in ["micro", "macro", "samples"]: raise ValueError(f"The `reduce` {reduce} is not valid.") diff --git a/torchmetrics/utilities/__init__.py b/torchmetrics/utilities/__init__.py index dff18c0f389..a853ad44c7c 100644 --- a/torchmetrics/utilities/__init__.py +++ b/torchmetrics/utilities/__init__.py @@ -1,3 +1,15 @@ +from typing import Any +from warnings import warn + from torchmetrics.utilities.data import apply_to_collection # noqa: F401 from torchmetrics.utilities.distributed import class_reduce, reduce # noqa: F401 from torchmetrics.utilities.prints import rank_zero_debug, rank_zero_info, rank_zero_warn # noqa: F401 + + +def _deprecation_warn_arg_multilabel(arg: Any) -> None: + if arg is None: + return + warn( + "Argument `multilabel` was deprecated in v0.3 and will be removed in v0.4. Use `multiclass` instead.", + DeprecationWarning + )