revisit deprecation of is_multiclass #193

Merged · 3 commits · Apr 21, 2021
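All of the new code paths route the deprecated `multilabel` argument through `_deprecation_warn_arg_multilabel`, which is imported from `torchmetrics.utilities` but whose body is not shown in this diff. A minimal sketch of what such a helper could look like, with the name and signature taken from the call sites below and the body assumed from the warning pattern it replaces:

```python
from typing import Optional
from warnings import warn


def _deprecation_warn_arg_multilabel(multilabel: Optional[bool]) -> None:
    """Warn that ``multilabel`` is deprecated and ignored (assumed body)."""
    if multilabel is None:
        # Argument not passed: nothing to warn about.
        return
    warn(
        "Argument `multilabel` was deprecated in v0.3.0 and will be removed in v0.4."
        " The argument has no effect; please use `multiclass` instead.",
        DeprecationWarning,
    )
```

Centralizing the warning in one helper keeps each `__init__` and functional entry point to a single line instead of repeating the same `warn(...)` block everywhere.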
27 changes: 12 additions & 15 deletions torchmetrics/classification/f_beta.py
@@ -12,13 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Optional
from warnings import warn

import torch
from torch import Tensor

from torchmetrics.classification.stat_scores import StatScores
from torchmetrics.functional.classification.f_beta import _fbeta_compute
from torchmetrics.utilities import _deprecation_warn_arg_multilabel


class FBeta(StatScores):
@@ -114,6 +114,9 @@ class FBeta(StatScores):
dist_sync_fn:
Callback that performs the allgather operation on the metric state. When ``None``, DDP
will be used to perform the allgather.
multilabel:
.. deprecated:: 0.3
Argument will not have any effect and will be removed in v0.4; please use ``multiclass`` instead.

Raises:
ValueError:
@@ -143,14 +146,9 @@ def __init__(
dist_sync_on_step: bool = False,
process_group: Optional[Any] = None,
dist_sync_fn: Callable = None,
is_multiclass: Optional[bool] = None, # todo: deprecated, remove in v0.4
multilabel: Optional[bool] = None, # todo: deprecated, remove in v0.4
):
if is_multiclass is not None and multiclass is None:
warn(
"Argument `is_multiclass` was deprecated in v0.3.0 and will be removed in v0.4. Use `multiclass`.",
DeprecationWarning
)
multiclass = is_multiclass
_deprecation_warn_arg_multilabel(multilabel)

self.beta = beta
allowed_average = ["micro", "macro", "weighted", "samples", "none", None]
@@ -269,6 +267,10 @@ class F1(FBeta):
dist_sync_fn:
Callback that performs the allgather operation on the metric state. When ``None``, DDP
will be used to perform the allgather.
multilabel:
.. deprecated:: 0.3
Argument will not have any effect and will be removed in v0.4; please use ``multiclass`` instead.


Example:
>>> from torchmetrics import F1
@@ -292,14 +294,9 @@ def __init__(
dist_sync_on_step: bool = False,
process_group: Optional[Any] = None,
dist_sync_fn: Callable = None,
is_multiclass: Optional[bool] = None, # todo: deprecated, remove in v0.4
multilabel: Optional[bool] = None, # todo: deprecated, remove in v0.4
):
if is_multiclass is not None and multiclass is None:
warn(
"Argument `is_multiclass` was deprecated in v0.3.0 and will be removed in v0.4. Use `multiclass`.",
DeprecationWarning
)
multiclass = is_multiclass
_deprecation_warn_arg_multilabel(multilabel)

super().__init__(
num_classes=num_classes,
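If the helper behaves as sketched above, constructing one of these metrics with the deprecated `multilabel` argument should emit a `DeprecationWarning` while the metric itself keeps working. An illustrative check, not part of the PR, using the doctest data from this file:

```python
import pytest
import torch

from torchmetrics import F1


def test_multilabel_argument_is_deprecated():
    # The expected DeprecationWarning follows from the helper sketched above
    # (an assumption); the metric should still compute normally afterwards.
    with pytest.deprecated_call():
        metric = F1(num_classes=3, multilabel=True)

    target = torch.tensor([0, 1, 2, 0, 1, 2])
    preds = torch.tensor([0, 2, 1, 0, 0, 1])
    assert torch.allclose(metric(preds, target), torch.tensor(0.3333), atol=1e-4)
```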
26 changes: 11 additions & 15 deletions torchmetrics/classification/precision_recall.py
@@ -12,13 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Optional
from warnings import warn

import torch
from torch import Tensor

from torchmetrics.classification.stat_scores import StatScores
from torchmetrics.functional.classification.precision_recall import _precision_compute, _recall_compute
from torchmetrics.utilities import _deprecation_warn_arg_multilabel


class Precision(StatScores):
@@ -104,6 +104,9 @@ class Precision(StatScores):
dist_sync_fn:
Callback that performs the allgather operation on the metric state. When ``None``, DDP
will be used to perform the allgather.
multilabel:
.. deprecated:: 0.3
Argument will not have any effect and will be removed in v0.4; please use ``multiclass`` instead.

Raises:
ValueError:
@@ -135,14 +138,9 @@ def __init__(
dist_sync_on_step: bool = False,
process_group: Optional[Any] = None,
dist_sync_fn: Callable = None,
is_multiclass: Optional[bool] = None, # todo: deprecated, remove in v0.4
multilabel: Optional[bool] = None, # todo: deprecated, remove in v0.4
):
if is_multiclass is not None and multiclass is None:
warn(
"Argument `is_multiclass` was deprecated in v0.3.0 and will be removed in v0.4. Use `multiclass`.",
DeprecationWarning
)
multiclass = is_multiclass
_deprecation_warn_arg_multilabel(multilabel)

allowed_average = ["micro", "macro", "weighted", "samples", "none", None]
if average not in allowed_average:
@@ -263,6 +261,9 @@ class Recall(StatScores):
dist_sync_fn:
Callback that performs the allgather operation on the metric state. When ``None``, DDP
will be used to perform the allgather.
multilabel:
.. deprecated:: 0.3
Argument will not have any effect and will be removed in v0.4; please use ``multiclass`` instead.

Raises:
ValueError:
@@ -294,14 +295,9 @@ def __init__(
dist_sync_on_step: bool = False,
process_group: Optional[Any] = None,
dist_sync_fn: Callable = None,
is_multiclass: Optional[bool] = None, # todo: deprecated, remove in v0.4
multilabel: Optional[bool] = None, # todo: deprecated, remove in v0.4
):
if is_multiclass is not None and multiclass is None:
warn(
"Argument `is_multiclass` was deprecated in v0.3.0 and will be removed in v0.4. Use `multiclass`.",
DeprecationWarning
)
multiclass = is_multiclass
_deprecation_warn_arg_multilabel(multilabel)

allowed_average = ["micro", "macro", "weighted", "samples", "none", None]
if average not in allowed_average:
8 changes: 0 additions & 8 deletions torchmetrics/classification/stat_scores.py
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Optional, Tuple
from warnings import warn

import numpy as np
import torch
@@ -146,14 +145,7 @@ def __init__(
dist_sync_on_step: bool = False,
process_group: Optional[Any] = None,
dist_sync_fn: Callable = None,
is_multiclass: Optional[bool] = None, # todo: deprecated, remove in v0.4
):
if is_multiclass is not None and multiclass is None:
warn(
"Argument `is_multiclass` was deprecated in v0.3.0 and will be removed in v0.4. Use `multiclass`.",
DeprecationWarning
)
multiclass = is_multiclass

super().__init__(
compute_on_step=compute_on_step,
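`StatScores` loses `is_multiclass` outright and gains no `multilabel` shim, so passing the old keyword should now fail at construction time. Assuming the constructor accepts no extra `**kwargs`, an illustrative check:

```python
import torch

from torchmetrics import StatScores

target = torch.tensor([1, 1, 2, 0])
preds = torch.tensor([1, 0, 2, 1])

# Supported spelling after this PR.
metric = StatScores(reduce="micro", num_classes=3, multiclass=True)
metric.update(preds, target)

try:
    # Hypothetical legacy call: `is_multiclass` is removed by this PR, so this
    # is expected to raise (assuming no **kwargs on the constructor).
    StatScores(reduce="micro", num_classes=3, is_multiclass=True)
except TypeError as err:
    print(f"as expected: {err}")
```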
26 changes: 11 additions & 15 deletions torchmetrics/functional/classification/f_beta.py
@@ -12,13 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
from warnings import warn

import torch
from torch import Tensor

from torchmetrics.classification.stat_scores import _reduce_stat_scores
from torchmetrics.functional.classification.stat_scores import _stat_scores_update
from torchmetrics.utilities import _deprecation_warn_arg_multilabel
from torchmetrics.utilities.enums import AverageMethod, MDMCAverageMethod


@@ -82,7 +82,7 @@ def fbeta(
threshold: float = 0.5,
top_k: Optional[int] = None,
multiclass: Optional[bool] = None,
is_multiclass: Optional[bool] = None, # todo: deprecated, remove in v0.4
multilabel: Optional[bool] = None, # todo: deprecated, remove in v0.4
) -> Tensor:
r"""
Computes f_beta metric.
@@ -158,6 +158,9 @@ def fbeta(
than what they appear to be. See the parameter's
:ref:`documentation section <references/modules:using the multiclass parameter>`
for a more detailed explanation and examples.
multilabel:
.. deprecated:: 0.3
Argument will not have any effect and will be removed in v0.4; please use ``multiclass`` instead.

Return:
The shape of the returned tensor depends on the ``average`` parameter
@@ -174,12 +177,7 @@
tensor(0.3333)

"""
if is_multiclass is not None and multiclass is None:
warn(
"Argument `is_multiclass` was deprecated in v0.3.0 and will be removed in v0.4. Use `multiclass`.",
DeprecationWarning
)
multiclass = is_multiclass
_deprecation_warn_arg_multilabel(multilabel)

allowed_average = ["micro", "macro", "weighted", "samples", "none", None]
if average not in allowed_average:
@@ -222,7 +220,7 @@ def f1(
threshold: float = 0.5,
top_k: Optional[int] = None,
multiclass: Optional[bool] = None,
is_multiclass: Optional[bool] = None, # todo: deprecated, remove in v0.4
multilabel: Optional[bool] = None, # todo: deprecated, remove in v0.4
) -> Tensor:
"""
Computes F1 metric. F1 metrics correspond to an equally weighted average of the
@@ -301,6 +299,9 @@ def f1(
than what they appear to be. See the parameter's
:ref:`documentation section <references/modules:using the multiclass parameter>`
for a more detailed explanation and examples.
multilabel:
.. deprecated:: 0.3
Argument will not have any effect and will be removed in v0.4; please use ``multiclass`` instead.

Return:
The shape of the returned tensor depends on the ``average`` parameter
@@ -316,10 +317,5 @@
>>> f1(preds, target, num_classes=3)
tensor(0.3333)
"""
if is_multiclass is not None and multiclass is None:
warn(
"Argument `is_multiclass` was deprecated in v0.3.0 and will be removed in v0.4. Use `multiclass`.",
DeprecationWarning
)
multiclass = is_multiclass
_deprecation_warn_arg_multilabel(multilabel)
return fbeta(preds, target, 1.0, average, mdmc_average, ignore_index, num_classes, threshold, top_k, multiclass)
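On the functional side the `is_multiclass` alias is dropped entirely rather than deprecated, so callers now spell the hint as `multiclass`. A small migration sketch, reusing the doctest inputs above:

```python
import torch

from torchmetrics.functional import f1

target = torch.tensor([0, 1, 2, 0, 1, 2])
preds = torch.tensor([0, 2, 1, 0, 0, 1])

# Before this PR the same hint could also be spelled `is_multiclass=True`;
# that alias is gone, so only the `multiclass` keyword remains.
score = f1(preds, target, num_classes=3, multiclass=True)
print(score)  # expected to match the doctest above: tensor(0.3333)
```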
38 changes: 16 additions & 22 deletions torchmetrics/functional/classification/precision_recall.py
@@ -12,13 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Tuple
from warnings import warn

import torch
from torch import Tensor

from torchmetrics.classification.stat_scores import _reduce_stat_scores
from torchmetrics.functional.classification.stat_scores import _stat_scores_update
from torchmetrics.utilities import _deprecation_warn_arg_multilabel


def _precision_compute(
@@ -49,7 +49,7 @@ def precision(
threshold: float = 0.5,
top_k: Optional[int] = None,
multiclass: Optional[bool] = None,
is_multiclass: Optional[bool] = None, # todo: deprecated, remove in v0.4
multilabel: Optional[bool] = None, # todo: deprecated, remove in v0.4
) -> Tensor:
r"""
Computes `Precision <https://en.wikipedia.org/wiki/Precision_and_recall>`_:
@@ -124,6 +124,9 @@ def precision(
than what they appear to be. See the parameter's
:ref:`documentation section <references/modules:using the multiclass parameter>`
for a more detailed explanation and examples.
multilabel:
.. deprecated:: 0.3
Argument will not have any effect and will be removed in v0.4; please use ``multiclass`` instead.

Return:
The shape of the returned tensor depends on the ``average`` parameter
@@ -154,12 +157,7 @@ def precision(
tensor(0.2500)

"""
if is_multiclass is not None and multiclass is None:
warn(
"Argument `is_multiclass` was deprecated in v0.3.0 and will be removed in v0.4. Use `multiclass`.",
DeprecationWarning
)
multiclass = is_multiclass
_deprecation_warn_arg_multilabel(multilabel)

allowed_average = ["micro", "macro", "weighted", "samples", "none", None]
if average not in allowed_average:
@@ -220,7 +218,7 @@ def recall(
threshold: float = 0.5,
top_k: Optional[int] = None,
multiclass: Optional[bool] = None,
is_multiclass: Optional[bool] = None, # todo: deprecated, remove in v0.4
multilabel: Optional[bool] = None, # todo: deprecated, remove in v0.4
) -> Tensor:
r"""
Computes `Recall <https://en.wikipedia.org/wiki/Precision_and_recall>`_:
@@ -295,6 +293,9 @@ def recall(
than what they appear to be. See the parameter's
:ref:`documentation section <references/modules:using the multiclass parameter>`
for a more detailed explanation and examples.
multilabel:
.. deprecated:: 0.3
Argument will not have any effect and will be removed in v0.4; please use ``multiclass`` instead.

Return:
The shape of the returned tensor depends on the ``average`` parameter
@@ -325,12 +326,7 @@ def recall(
tensor(0.2500)

"""
if is_multiclass is not None and multiclass is None:
warn(
"Argument `is_multiclass` was deprecated in v0.3.0 and will be removed in v0.4. Use `multiclass`.",
DeprecationWarning
)
multiclass = is_multiclass
_deprecation_warn_arg_multilabel(multilabel)

allowed_average = ["micro", "macro", "weighted", "samples", "none", None]
if average not in allowed_average:
@@ -372,7 +368,7 @@ def precision_recall(
threshold: float = 0.5,
top_k: Optional[int] = None,
multiclass: Optional[bool] = None,
is_multiclass: Optional[bool] = None, # todo: deprecated, remove in v0.4
multilabel: Optional[bool] = None, # todo: deprecated, remove in v0.4
) -> Tuple[Tensor, Tensor]:
r"""
Computes `Precision and Recall <https://en.wikipedia.org/wiki/Precision_and_recall>`_:
@@ -450,6 +446,9 @@ def precision_recall(
than what they appear to be. See the parameter's
:ref:`documentation section <references/modules:using the multiclass parameter>`
for a more detailed explanation and examples.
multilabel:
.. deprecated:: 0.3
Argument will not have any effect and will be removed in v0.4; please use ``multiclass`` instead.

Return:
The function returns a tuple with two elements: precision and recall. Their shape
@@ -481,12 +480,7 @@ def precision_recall(
(tensor(0.2500), tensor(0.2500))

"""
if is_multiclass is not None and multiclass is None:
warn(
"Argument `is_multiclass` was deprecated in v0.3.0 and will be removed in v0.4. Use `multiclass`.",
DeprecationWarning
)
multiclass = is_multiclass
_deprecation_warn_arg_multilabel(multilabel)

allowed_average = ["micro", "macro", "weighted", "samples", "none", None]
if average not in allowed_average:
15 changes: 0 additions & 15 deletions torchmetrics/functional/classification/stat_scores.py
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Tuple
from warnings import warn

import torch
from torch import Tensor, tensor
@@ -85,14 +84,7 @@ def _stat_scores_update(
threshold: float = 0.5,
multiclass: Optional[bool] = None,
ignore_index: Optional[int] = None,
is_multiclass: Optional[bool] = None, # todo: deprecated, remove in v0.4
) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
if is_multiclass is not None and multiclass is None:
warn(
"Argument `is_multiclass` was deprecated in v0.3.0 and will be removed in v0.4. Use `multiclass`.",
DeprecationWarning
)
multiclass = is_multiclass

preds, target, _ = _input_format_classification(
preds, target, threshold=threshold, num_classes=num_classes, multiclass=multiclass, top_k=top_k
@@ -155,7 +147,6 @@ def stat_scores(
threshold: float = 0.5,
multiclass: Optional[bool] = None,
ignore_index: Optional[int] = None,
is_multiclass: Optional[bool] = None, # todo: deprecated, remove in v0.4
) -> Tensor:
"""Computes the number of true positives, false positives, true negatives, false negatives.
Related to `Type I and Type II errors <https://en.wikipedia.org/wiki/Type_I_and_type_II_errors>`__
@@ -280,12 +271,6 @@
>>> stat_scores(preds, target, reduce='micro')
tensor([2, 2, 6, 2, 4])
"""
if is_multiclass is not None and multiclass is None:
warn(
"Argument `is_multiclass` was deprecated in v0.3.0 and will be removed in v0.4. Use `multiclass`.",
DeprecationWarning
)
multiclass = is_multiclass

if reduce not in ["micro", "macro", "samples"]:
raise ValueError(f"The `reduce` {reduce} is not valid.")