Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ClassificationTask Enum #1479

Merged
merged 32 commits into from
Feb 7, 2023
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
b7a8dc7
add enum
SkafteNicki Feb 3, 2023
aee0f12
add enum
SkafteNicki Feb 3, 2023
3bec023
gh: update templates (#1477)Co-authored-by: pre-commit-ci[bot] <66853…
Borda Feb 3, 2023
20b94d5
Merge branch 'master' into classification/enum
SkafteNicki Feb 3, 2023
0140d12
add enum
SkafteNicki Feb 3, 2023
9bd0063
add enum
SkafteNicki Feb 3, 2023
5a4caef
StrEnum
Borda Feb 3, 2023
169586f
utils 0.5.0
Borda Feb 5, 2023
2720af1
with error
Borda Feb 5, 2023
a6cd44f
links
Borda Feb 6, 2023
cbc794c
Merge branch 'master' into classification/enum
mergify[bot] Feb 6, 2023
fa0fa16
property
Borda Feb 6, 2023
64da325
_name
Borda Feb 6, 2023
f9682a0
Merge branch 'master' into classification/enum
Borda Feb 6, 2023
0ca463d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 6, 2023
7b81238
Merge branch 'master' into classification/enum
mergify[bot] Feb 6, 2023
5355483
chlog
Borda Feb 6, 2023
e2ab95d
Merge branch 'classification/enum' of https://github.com/PyTorchLight…
Borda Feb 6, 2023
fbcc0e5
docstring
Borda Feb 6, 2023
67aeb2f
Merge branch 'master' into classification/enum
mergify[bot] Feb 6, 2023
49d7d7b
Merge branch 'classification/enum' of https://github.com/PyTorchLight…
SkafteNicki Feb 6, 2023
01b693c
remove valueerror + add from_str eval
SkafteNicki Feb 6, 2023
ea8ecc0
Merge branch 'master' into classification/enum
mergify[bot] Feb 6, 2023
c62dc64
doctests
Borda Feb 6, 2023
33bd5a5
Merge branch 'master' into classification/enum
mergify[bot] Feb 7, 2023
940dd6e
docs
Borda Feb 7, 2023
8adf66f
Merge branch 'classification/enum' of https://github.com/PyTorchLight…
Borda Feb 7, 2023
5424c1f
Merge branch 'master' into classification/enum
mergify[bot] Feb 7, 2023
76dc872
-
Borda Feb 7, 2023
aca9f38
mypy
Borda Feb 7, 2023
70314f0
mypy
Borda Feb 7, 2023
2067b00
Merge branch 'master' into classification/enum
mergify[bot] Feb 7, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions src/torchmetrics/classification/accuracy.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

from torchmetrics.functional.classification.accuracy import _accuracy_reduce
from torchmetrics.metric import Metric
from torchmetrics.utilities.enums import ClassificationTask
from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE
from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE, plot_single_or_multi_val

Expand Down Expand Up @@ -453,13 +454,13 @@ def __new__(
kwargs.update(
{"multidim_average": multidim_average, "ignore_index": ignore_index, "validate_args": validate_args}
)
if task == "binary":
if task == ClassificationTask.BINARY:
return BinaryAccuracy(threshold, **kwargs)
if task == "multiclass":
if task == ClassificationTask.MULTICLASS:
assert isinstance(num_classes, int)
assert isinstance(top_k, int)
return MulticlassAccuracy(num_classes, top_k, average, **kwargs)
if task == "multilabel":
if task == ClassificationTask.MULTILABEL:
assert isinstance(num_labels, int)
return MultilabelAccuracy(num_labels, threshold, average, **kwargs)
raise ValueError(
Expand Down
7 changes: 4 additions & 3 deletions src/torchmetrics/classification/auroc.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
)
from torchmetrics.metric import Metric
from torchmetrics.utilities.data import dim_zero_cat
from torchmetrics.utilities.enums import ClassificationTask


class BinaryAUROC(BinaryPrecisionRecallCurve):
Expand Down Expand Up @@ -353,12 +354,12 @@ def __new__(
**kwargs: Any,
) -> Metric:
kwargs.update({"thresholds": thresholds, "ignore_index": ignore_index, "validate_args": validate_args})
if task == "binary":
if task == ClassificationTask.BINARY:
return BinaryAUROC(max_fpr, **kwargs)
if task == "multiclass":
if task == ClassificationTask.MULTICLASS:
assert isinstance(num_classes, int)
return MulticlassAUROC(num_classes, average, **kwargs)
if task == "multilabel":
if task == ClassificationTask.MULTILABEL:
assert isinstance(num_labels, int)
return MultilabelAUROC(num_labels, average, **kwargs)
raise ValueError(
Expand Down
7 changes: 4 additions & 3 deletions src/torchmetrics/classification/average_precision.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
)
from torchmetrics.metric import Metric
from torchmetrics.utilities.data import dim_zero_cat
from torchmetrics.utilities.enums import ClassificationTask


class BinaryAveragePrecision(BinaryPrecisionRecallCurve):
Expand Down Expand Up @@ -357,12 +358,12 @@ def __new__(
**kwargs: Any,
) -> Metric:
kwargs.update({"thresholds": thresholds, "ignore_index": ignore_index, "validate_args": validate_args})
if task == "binary":
if task == ClassificationTask.BINARY:
return BinaryAveragePrecision(**kwargs)
if task == "multiclass":
if task == ClassificationTask.MULTICLASS:
assert isinstance(num_classes, int)
return MulticlassAveragePrecision(num_classes, average, **kwargs)
if task == "multilabel":
if task == ClassificationTask.MULTILABEL:
assert isinstance(num_labels, int)
return MultilabelAveragePrecision(num_labels, average, **kwargs)
raise ValueError(
Expand Down
5 changes: 3 additions & 2 deletions src/torchmetrics/classification/calibration_error.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
)
from torchmetrics.metric import Metric
from torchmetrics.utilities.data import dim_zero_cat
from torchmetrics.utilities.enums import ClassificationTask


class BinaryCalibrationError(Metric):
Expand Down Expand Up @@ -268,9 +269,9 @@ def __new__(
**kwargs: Any,
) -> Metric:
kwargs.update({"n_bins": n_bins, "norm": norm, "ignore_index": ignore_index, "validate_args": validate_args})
if task == "binary":
if task == ClassificationTask.BINARY:
return BinaryCalibrationError(**kwargs)
if task == "multiclass":
if task == ClassificationTask.MULTICLASS:
assert isinstance(num_classes, int)
return MulticlassCalibrationError(num_classes, **kwargs)
raise ValueError(
Expand Down
5 changes: 3 additions & 2 deletions src/torchmetrics/classification/cohen_kappa.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
_multiclass_cohen_kappa_arg_validation,
)
from torchmetrics.metric import Metric
from torchmetrics.utilities.enums import ClassificationTask


class BinaryCohenKappa(BinaryConfusionMatrix):
Expand Down Expand Up @@ -222,9 +223,9 @@ def __new__(
**kwargs: Any,
) -> Metric:
kwargs.update({"weights": weights, "ignore_index": ignore_index, "validate_args": validate_args})
if task == "binary":
if task == ClassificationTask.BINARY:
return BinaryCohenKappa(threshold, **kwargs)
if task == "multiclass":
if task == ClassificationTask.MULTICLASS:
assert isinstance(num_classes, int)
return MulticlassCohenKappa(num_classes, **kwargs)
raise ValueError(
Expand Down
7 changes: 4 additions & 3 deletions src/torchmetrics/classification/confusion_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
_multilabel_confusion_matrix_update,
)
from torchmetrics.metric import Metric
from torchmetrics.utilities.enums import ClassificationTask
from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE
from torchmetrics.utilities.plot import _PLOT_OUT_TYPE, plot_confusion_matrix

Expand Down Expand Up @@ -398,12 +399,12 @@ def __new__(
**kwargs: Any,
) -> Metric:
kwargs.update({"normalize": normalize, "ignore_index": ignore_index, "validate_args": validate_args})
if task == "binary":
if task == ClassificationTask.BINARY:
return BinaryConfusionMatrix(threshold, **kwargs)
if task == "multiclass":
if task == ClassificationTask.MULTICLASS:
assert isinstance(num_classes, int)
return MulticlassConfusionMatrix(num_classes, **kwargs)
if task == "multilabel":
if task == ClassificationTask.MULTILABEL:
assert isinstance(num_labels, int)
return MultilabelConfusionMatrix(num_labels, threshold, **kwargs)
raise ValueError(
Expand Down
5 changes: 3 additions & 2 deletions src/torchmetrics/classification/exact_match.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
)
from torchmetrics.metric import Metric
from torchmetrics.utilities.data import dim_zero_cat
from torchmetrics.utilities.enums import ClassificationTask


class MulticlassExactMatch(Metric):
Expand Down Expand Up @@ -291,10 +292,10 @@ def __new__(
kwargs.update(
{"multidim_average": multidim_average, "ignore_index": ignore_index, "validate_args": validate_args}
)
if task == "multiclass":
if task == ClassificationTask.MULTICLASS:
assert isinstance(num_classes, int)
return MulticlassExactMatch(num_classes, **kwargs)
if task == "multilabel":
if task == ClassificationTask.MULTILABEL:
assert isinstance(num_labels, int)
return MultilabelExactMatch(num_labels, threshold, **kwargs)
raise ValueError(f"Expected argument `task` to either be `'multiclass'` or `'multilabel'` but got {task}")
13 changes: 7 additions & 6 deletions src/torchmetrics/classification/f_beta.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
_multilabel_fbeta_score_arg_validation,
)
from torchmetrics.metric import Metric
from torchmetrics.utilities.enums import ClassificationTask


class BinaryFBetaScore(BinaryStatScores):
Expand Down Expand Up @@ -729,13 +730,13 @@ def __new__(
kwargs.update(
{"multidim_average": multidim_average, "ignore_index": ignore_index, "validate_args": validate_args}
)
if task == "binary":
if task == ClassificationTask.BINARY:
return BinaryFBetaScore(beta, threshold, **kwargs)
if task == "multiclass":
if task == ClassificationTask.MULTICLASS:
assert isinstance(num_classes, int)
assert isinstance(top_k, int)
return MulticlassFBetaScore(beta, num_classes, top_k, average, **kwargs)
if task == "multilabel":
if task == ClassificationTask.MULTILABEL:
assert isinstance(num_labels, int)
return MultilabelFBetaScore(beta, num_labels, threshold, average, **kwargs)
raise ValueError(
Expand Down Expand Up @@ -780,13 +781,13 @@ def __new__(
kwargs.update(
{"multidim_average": multidim_average, "ignore_index": ignore_index, "validate_args": validate_args}
)
if task == "binary":
if task == ClassificationTask.BINARY:
return BinaryF1Score(threshold, **kwargs)
if task == "multiclass":
if task == ClassificationTask.MULTICLASS:
assert isinstance(num_classes, int)
assert isinstance(top_k, int)
return MulticlassF1Score(num_classes, top_k, average, **kwargs)
if task == "multilabel":
if task == ClassificationTask.MULTILABEL:
assert isinstance(num_labels, int)
return MultilabelF1Score(num_labels, threshold, average, **kwargs)
raise ValueError(
Expand Down
7 changes: 4 additions & 3 deletions src/torchmetrics/classification/hamming.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from torchmetrics.classification.stat_scores import BinaryStatScores, MulticlassStatScores, MultilabelStatScores
from torchmetrics.functional.classification.hamming import _hamming_distance_reduce
from torchmetrics.metric import Metric
from torchmetrics.utilities.enums import ClassificationTask


class BinaryHammingDistance(BinaryStatScores):
Expand Down Expand Up @@ -349,13 +350,13 @@ def __new__(
kwargs.update(
{"multidim_average": multidim_average, "ignore_index": ignore_index, "validate_args": validate_args}
)
if task == "binary":
if task == ClassificationTask.BINARY:
return BinaryHammingDistance(threshold, **kwargs)
if task == "multiclass":
if task == ClassificationTask.MULTICLASS:
assert isinstance(num_classes, int)
assert isinstance(top_k, int)
return MulticlassHammingDistance(num_classes, top_k, average, **kwargs)
if task == "multilabel":
if task == ClassificationTask.MULTILABEL:
assert isinstance(num_labels, int)
return MultilabelHammingDistance(num_labels, threshold, average, **kwargs)
raise ValueError(
Expand Down
5 changes: 3 additions & 2 deletions src/torchmetrics/classification/hinge.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
_multiclass_hinge_loss_update,
)
from torchmetrics.metric import Metric
from torchmetrics.utilities.enums import ClassificationTask


class BinaryHingeLoss(Metric):
Expand Down Expand Up @@ -253,9 +254,9 @@ def __new__(
**kwargs: Any,
) -> Metric:
kwargs.update({"ignore_index": ignore_index, "validate_args": validate_args})
if task == "binary":
if task == ClassificationTask.BINARY:
return BinaryHingeLoss(squared, **kwargs)
if task == "multiclass":
if task == ClassificationTask.MULTICLASS:
assert isinstance(num_classes, int)
return MulticlassHingeLoss(num_classes, squared, multiclass_mode, **kwargs)
raise ValueError(
Expand Down
7 changes: 4 additions & 3 deletions src/torchmetrics/classification/jaccard.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
_multilabel_jaccard_index_arg_validation,
)
from torchmetrics.metric import Metric
from torchmetrics.utilities.enums import ClassificationTask


class BinaryJaccardIndex(BinaryConfusionMatrix):
Expand Down Expand Up @@ -296,12 +297,12 @@ def __new__(
**kwargs: Any,
) -> Metric:
kwargs.update({"ignore_index": ignore_index, "validate_args": validate_args})
if task == "binary":
if task == ClassificationTask.BINARY:
return BinaryJaccardIndex(threshold, **kwargs)
if task == "multiclass":
if task == ClassificationTask.MULTICLASS:
assert isinstance(num_classes, int)
return MulticlassJaccardIndex(num_classes, average, **kwargs)
if task == "multilabel":
if task == ClassificationTask.MULTILABEL:
assert isinstance(num_labels, int)
return MultilabelJaccardIndex(num_labels, threshold, average, **kwargs)
raise ValueError(
Expand Down
7 changes: 4 additions & 3 deletions src/torchmetrics/classification/matthews_corrcoef.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from torchmetrics.classification import BinaryConfusionMatrix, MulticlassConfusionMatrix, MultilabelConfusionMatrix
from torchmetrics.functional.classification.matthews_corrcoef import _matthews_corrcoef_reduce
from torchmetrics.metric import Metric
from torchmetrics.utilities.enums import ClassificationTask


class BinaryMatthewsCorrCoef(BinaryConfusionMatrix):
Expand Down Expand Up @@ -238,12 +239,12 @@ def __new__(
**kwargs: Any,
) -> Metric:
kwargs.update({"ignore_index": ignore_index, "validate_args": validate_args})
if task == "binary":
if task == ClassificationTask.BINARY:
return BinaryMatthewsCorrCoef(threshold, **kwargs)
if task == "multiclass":
if task == ClassificationTask.MULTICLASS:
assert isinstance(num_classes, int)
return MulticlassMatthewsCorrCoef(num_classes, **kwargs)
if task == "multilabel":
if task == ClassificationTask.MULTILABEL:
assert isinstance(num_labels, int)
return MultilabelMatthewsCorrCoef(num_labels, threshold, **kwargs)
raise ValueError(
Expand Down
13 changes: 7 additions & 6 deletions src/torchmetrics/classification/precision_recall.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from torchmetrics.classification.stat_scores import BinaryStatScores, MulticlassStatScores, MultilabelStatScores
from torchmetrics.functional.classification.precision_recall import _precision_recall_reduce
from torchmetrics.metric import Metric
from torchmetrics.utilities.enums import ClassificationTask


class BinaryPrecision(BinaryStatScores):
Expand Down Expand Up @@ -620,13 +621,13 @@ def __new__(
kwargs.update(
{"multidim_average": multidim_average, "ignore_index": ignore_index, "validate_args": validate_args}
)
if task == "binary":
if task == ClassificationTask.BINARY:
return BinaryPrecision(threshold, **kwargs)
if task == "multiclass":
if task == ClassificationTask.MULTICLASS:
assert isinstance(num_classes, int)
assert isinstance(top_k, int)
return MulticlassPrecision(num_classes, top_k, average, **kwargs)
if task == "multilabel":
if task == ClassificationTask.MULTILABEL:
assert isinstance(num_labels, int)
return MultilabelPrecision(num_labels, threshold, average, **kwargs)
raise ValueError(
Expand Down Expand Up @@ -676,13 +677,13 @@ def __new__(
kwargs.update(
{"multidim_average": multidim_average, "ignore_index": ignore_index, "validate_args": validate_args}
)
if task == "binary":
if task == ClassificationTask.BINARY:
return BinaryRecall(threshold, **kwargs)
if task == "multiclass":
if task == ClassificationTask.MULTICLASS:
assert isinstance(num_classes, int)
assert isinstance(top_k, int)
return MulticlassRecall(num_classes, top_k, average, **kwargs)
if task == "multilabel":
if task == ClassificationTask.MULTILABEL:
assert isinstance(num_labels, int)
return MultilabelRecall(num_labels, threshold, average, **kwargs)
raise ValueError(
Expand Down
7 changes: 4 additions & 3 deletions src/torchmetrics/classification/precision_recall_curve.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
)
from torchmetrics.metric import Metric
from torchmetrics.utilities.data import dim_zero_cat
from torchmetrics.utilities.enums import ClassificationTask


class BinaryPrecisionRecallCurve(Metric):
Expand Down Expand Up @@ -467,12 +468,12 @@ def __new__(
**kwargs: Any,
) -> Metric:
kwargs.update({"thresholds": thresholds, "ignore_index": ignore_index, "validate_args": validate_args})
if task == "binary":
if task == ClassificationTask.BINARY:
return BinaryPrecisionRecallCurve(**kwargs)
if task == "multiclass":
if task == ClassificationTask.MULTICLASS:
assert isinstance(num_classes, int)
return MulticlassPrecisionRecallCurve(num_classes, **kwargs)
if task == "multilabel":
if task == ClassificationTask.MULTILABEL:
assert isinstance(num_labels, int)
return MultilabelPrecisionRecallCurve(num_labels, **kwargs)
raise ValueError(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
)
from torchmetrics.metric import Metric
from torchmetrics.utilities.data import dim_zero_cat
from torchmetrics.utilities.enums import ClassificationTask


class BinaryRecallAtFixedPrecision(BinaryPrecisionRecallCurve):
Expand Down Expand Up @@ -323,14 +324,14 @@ def __new__(
validate_args: bool = True,
**kwargs: Any,
) -> Metric:
if task == "binary":
if task == ClassificationTask.BINARY:
return BinaryRecallAtFixedPrecision(min_precision, thresholds, ignore_index, validate_args, **kwargs)
if task == "multiclass":
if task == ClassificationTask.MULTICLASS:
assert isinstance(num_classes, int)
return MulticlassRecallAtFixedPrecision(
num_classes, min_precision, thresholds, ignore_index, validate_args, **kwargs
)
if task == "multilabel":
if task == ClassificationTask.MULTILABEL:
assert isinstance(num_labels, int)
return MultilabelRecallAtFixedPrecision(
num_labels, min_precision, thresholds, ignore_index, validate_args, **kwargs
Expand Down
Loading