diff --git a/CHANGELOG.md b/CHANGELOG.md index e53ec4663bd..a29098d38d1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -23,6 +23,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added `classes` to output from `MAP` metric ([#1419](https://github.com/Lightning-AI/metrics/pull/1419)) +- Add `ClassificationTask` Enum and use in metrics ([#1479](https://github.com/Lightning-AI/metrics/pull/1479)) + + ### Changed - Changed `update_count` and `update_called` from private to public methods ([#1370](https://github.com/Lightning-AI/metrics/pull/1370)) @@ -31,6 +34,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Raise exception for invalid kwargs in Metric base class ([#1427](https://github.com/Lightning-AI/metrics/pull/1427)) +- Extend `EnumStr` raising `ValueError` for invalid value ([#1479](https://github.com/Lightning-AI/metrics/pull/1479)) + + ### Deprecated - diff --git a/requirements.txt b/requirements.txt index a121c783161..4b5ec226720 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,4 +2,4 @@ numpy>=1.17.2 torch>=1.8.1 typing-extensions; python_version < '3.9' packaging # hotfix for utils, can be dropped with lit-utils >=0.5 -lightning-utilities>=0.4.1 +lightning-utilities>=0.5.0 diff --git a/src/torchmetrics/classification/accuracy.py b/src/torchmetrics/classification/accuracy.py index 798e8e0eacc..f68a05db6db 100644 --- a/src/torchmetrics/classification/accuracy.py +++ b/src/torchmetrics/classification/accuracy.py @@ -18,6 +18,7 @@ from torchmetrics.functional.classification.accuracy import _accuracy_reduce from torchmetrics.metric import Metric +from torchmetrics.utilities.enums import ClassificationTask from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE, plot_single_or_multi_val @@ -490,18 +491,16 @@ def __new__( validate_args: bool = True, **kwargs: Any, ) -> Metric: + task = ClassificationTask.from_str(task) kwargs.update( {"multidim_average": multidim_average, "ignore_index": ignore_index, "validate_args": validate_args} ) - if task == "binary": + if task == ClassificationTask.BINARY: return BinaryAccuracy(threshold, **kwargs) - if task == "multiclass": + if task == ClassificationTask.MULTICLASS: assert isinstance(num_classes, int) assert isinstance(top_k, int) return MulticlassAccuracy(num_classes, top_k, average, **kwargs) - if task == "multilabel": + if task == ClassificationTask.MULTILABEL: assert isinstance(num_labels, int) return MultilabelAccuracy(num_labels, threshold, average, **kwargs) - raise ValueError( - f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" - ) diff --git a/src/torchmetrics/classification/auroc.py b/src/torchmetrics/classification/auroc.py index 85768340a23..14bb0b2e953 100644 --- a/src/torchmetrics/classification/auroc.py +++ b/src/torchmetrics/classification/auroc.py @@ -31,6 +31,7 @@ ) from torchmetrics.metric import Metric from torchmetrics.utilities.data import dim_zero_cat +from torchmetrics.utilities.enums import ClassificationTask class BinaryAUROC(BinaryPrecisionRecallCurve): @@ -352,15 +353,13 @@ def __new__( validate_args: bool = True, **kwargs: Any, ) -> Metric: + task = ClassificationTask.from_str(task) kwargs.update({"thresholds": thresholds, "ignore_index": ignore_index, "validate_args": validate_args}) - if task == "binary": + if task == ClassificationTask.BINARY: return 
BinaryAUROC(max_fpr, **kwargs) - if task == "multiclass": + if task == ClassificationTask.MULTICLASS: assert isinstance(num_classes, int) return MulticlassAUROC(num_classes, average, **kwargs) - if task == "multilabel": + if task == ClassificationTask.MULTILABEL: assert isinstance(num_labels, int) return MultilabelAUROC(num_labels, average, **kwargs) - raise ValueError( - f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" - ) diff --git a/src/torchmetrics/classification/average_precision.py b/src/torchmetrics/classification/average_precision.py index 979e11e5db7..c50d4708bf4 100644 --- a/src/torchmetrics/classification/average_precision.py +++ b/src/torchmetrics/classification/average_precision.py @@ -30,6 +30,7 @@ ) from torchmetrics.metric import Metric from torchmetrics.utilities.data import dim_zero_cat +from torchmetrics.utilities.enums import ClassificationTask class BinaryAveragePrecision(BinaryPrecisionRecallCurve): @@ -356,15 +357,13 @@ def __new__( validate_args: bool = True, **kwargs: Any, ) -> Metric: + task = ClassificationTask.from_str(task) kwargs.update({"thresholds": thresholds, "ignore_index": ignore_index, "validate_args": validate_args}) - if task == "binary": + if task == ClassificationTask.BINARY: return BinaryAveragePrecision(**kwargs) - if task == "multiclass": + if task == ClassificationTask.MULTICLASS: assert isinstance(num_classes, int) return MulticlassAveragePrecision(num_classes, average, **kwargs) - if task == "multilabel": + if task == ClassificationTask.MULTILABEL: assert isinstance(num_labels, int) return MultilabelAveragePrecision(num_labels, average, **kwargs) - raise ValueError( - f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" - ) diff --git a/src/torchmetrics/classification/calibration_error.py b/src/torchmetrics/classification/calibration_error.py index 544fd03debc..3518870d7fa 100644 --- a/src/torchmetrics/classification/calibration_error.py +++ b/src/torchmetrics/classification/calibration_error.py @@ -29,6 +29,7 @@ ) from torchmetrics.metric import Metric from torchmetrics.utilities.data import dim_zero_cat +from torchmetrics.utilities.enums import ClassificationTaskNoMultilabel class BinaryCalibrationError(Metric): @@ -267,12 +268,10 @@ def __new__( validate_args: bool = True, **kwargs: Any, ) -> Metric: + task = ClassificationTaskNoMultilabel.from_str(task) kwargs.update({"n_bins": n_bins, "norm": norm, "ignore_index": ignore_index, "validate_args": validate_args}) - if task == "binary": + if task == ClassificationTaskNoMultilabel.BINARY: return BinaryCalibrationError(**kwargs) - if task == "multiclass": + if task == ClassificationTaskNoMultilabel.MULTICLASS: assert isinstance(num_classes, int) return MulticlassCalibrationError(num_classes, **kwargs) - raise ValueError( - f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" - ) diff --git a/src/torchmetrics/classification/cohen_kappa.py b/src/torchmetrics/classification/cohen_kappa.py index f2c9bcc0b1b..524d8ff5d7d 100644 --- a/src/torchmetrics/classification/cohen_kappa.py +++ b/src/torchmetrics/classification/cohen_kappa.py @@ -23,6 +23,7 @@ _multiclass_cohen_kappa_arg_validation, ) from torchmetrics.metric import Metric +from torchmetrics.utilities.enums import ClassificationTaskNoMultilabel class BinaryCohenKappa(BinaryConfusionMatrix): @@ -221,12 +222,10 @@ def __new__( validate_args: bool = True, **kwargs: Any, ) -> Metric: + task = 
ClassificationTaskNoMultilabel.from_str(task) kwargs.update({"weights": weights, "ignore_index": ignore_index, "validate_args": validate_args}) - if task == "binary": + if task == ClassificationTaskNoMultilabel.BINARY: return BinaryCohenKappa(threshold, **kwargs) - if task == "multiclass": + if task == ClassificationTaskNoMultilabel.MULTICLASS: assert isinstance(num_classes, int) return MulticlassCohenKappa(num_classes, **kwargs) - raise ValueError( - f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" - ) diff --git a/src/torchmetrics/classification/confusion_matrix.py b/src/torchmetrics/classification/confusion_matrix.py index 7018956ed6a..2f9277d76a1 100644 --- a/src/torchmetrics/classification/confusion_matrix.py +++ b/src/torchmetrics/classification/confusion_matrix.py @@ -35,6 +35,7 @@ _multilabel_confusion_matrix_update, ) from torchmetrics.metric import Metric +from torchmetrics.utilities.enums import ClassificationTask from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE from torchmetrics.utilities.plot import _PLOT_OUT_TYPE, plot_confusion_matrix @@ -397,15 +398,13 @@ def __new__( validate_args: bool = True, **kwargs: Any, ) -> Metric: + task = ClassificationTask.from_str(task) kwargs.update({"normalize": normalize, "ignore_index": ignore_index, "validate_args": validate_args}) - if task == "binary": + if task == ClassificationTask.BINARY: return BinaryConfusionMatrix(threshold, **kwargs) - if task == "multiclass": + if task == ClassificationTask.MULTICLASS: assert isinstance(num_classes, int) return MulticlassConfusionMatrix(num_classes, **kwargs) - if task == "multilabel": + if task == ClassificationTask.MULTILABEL: assert isinstance(num_labels, int) return MultilabelConfusionMatrix(num_labels, threshold, **kwargs) - raise ValueError( - f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" - ) diff --git a/src/torchmetrics/classification/exact_match.py b/src/torchmetrics/classification/exact_match.py index 5c25e6e9d09..82dfa3e8747 100644 --- a/src/torchmetrics/classification/exact_match.py +++ b/src/torchmetrics/classification/exact_match.py @@ -32,6 +32,7 @@ ) from torchmetrics.metric import Metric from torchmetrics.utilities.data import dim_zero_cat +from torchmetrics.utilities.enums import ClassificationTaskNoBinary class MulticlassExactMatch(Metric): @@ -288,13 +289,13 @@ def __new__( validate_args: bool = True, **kwargs: Any, ) -> Metric: + task = ClassificationTaskNoBinary.from_str(task) kwargs.update( {"multidim_average": multidim_average, "ignore_index": ignore_index, "validate_args": validate_args} ) - if task == "multiclass": + if task == ClassificationTaskNoBinary.MULTICLASS: assert isinstance(num_classes, int) return MulticlassExactMatch(num_classes, **kwargs) - if task == "multilabel": + if task == ClassificationTaskNoBinary.MULTILABEL: assert isinstance(num_labels, int) return MultilabelExactMatch(num_labels, threshold, **kwargs) - raise ValueError(f"Expected argument `task` to either be `'multiclass'` or `'multilabel'` but got {task}") diff --git a/src/torchmetrics/classification/f_beta.py b/src/torchmetrics/classification/f_beta.py index c3dfa4a5b26..4dd02190e27 100644 --- a/src/torchmetrics/classification/f_beta.py +++ b/src/torchmetrics/classification/f_beta.py @@ -24,6 +24,7 @@ _multilabel_fbeta_score_arg_validation, ) from torchmetrics.metric import Metric +from torchmetrics.utilities.enums import ClassificationTask class 
BinaryFBetaScore(BinaryStatScores): @@ -729,13 +730,13 @@ def __new__( kwargs.update( {"multidim_average": multidim_average, "ignore_index": ignore_index, "validate_args": validate_args} ) - if task == "binary": + if task == ClassificationTask.BINARY: return BinaryFBetaScore(beta, threshold, **kwargs) - if task == "multiclass": + if task == ClassificationTask.MULTICLASS: assert isinstance(num_classes, int) assert isinstance(top_k, int) return MulticlassFBetaScore(beta, num_classes, top_k, average, **kwargs) - if task == "multilabel": + if task == ClassificationTask.MULTILABEL: assert isinstance(num_labels, int) return MultilabelFBetaScore(beta, num_labels, threshold, average, **kwargs) raise ValueError( @@ -776,19 +777,17 @@ def __new__( validate_args: bool = True, **kwargs: Any, ) -> Metric: + task = ClassificationTask.from_str(task) assert multidim_average is not None kwargs.update( {"multidim_average": multidim_average, "ignore_index": ignore_index, "validate_args": validate_args} ) - if task == "binary": + if task == ClassificationTask.BINARY: return BinaryF1Score(threshold, **kwargs) - if task == "multiclass": + if task == ClassificationTask.MULTICLASS: assert isinstance(num_classes, int) assert isinstance(top_k, int) return MulticlassF1Score(num_classes, top_k, average, **kwargs) - if task == "multilabel": + if task == ClassificationTask.MULTILABEL: assert isinstance(num_labels, int) return MultilabelF1Score(num_labels, threshold, average, **kwargs) - raise ValueError( - f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" - ) diff --git a/src/torchmetrics/classification/hamming.py b/src/torchmetrics/classification/hamming.py index 636215ae145..ccdd13400f8 100644 --- a/src/torchmetrics/classification/hamming.py +++ b/src/torchmetrics/classification/hamming.py @@ -19,6 +19,7 @@ from torchmetrics.classification.stat_scores import BinaryStatScores, MulticlassStatScores, MultilabelStatScores from torchmetrics.functional.classification.hamming import _hamming_distance_reduce from torchmetrics.metric import Metric +from torchmetrics.utilities.enums import ClassificationTask class BinaryHammingDistance(BinaryStatScores): @@ -344,20 +345,17 @@ def __new__( validate_args: bool = True, **kwargs: Any, ) -> Metric: - + task = ClassificationTask.from_str(task) assert multidim_average is not None kwargs.update( {"multidim_average": multidim_average, "ignore_index": ignore_index, "validate_args": validate_args} ) - if task == "binary": + if task == ClassificationTask.BINARY: return BinaryHammingDistance(threshold, **kwargs) - if task == "multiclass": + if task == ClassificationTask.MULTICLASS: assert isinstance(num_classes, int) assert isinstance(top_k, int) return MulticlassHammingDistance(num_classes, top_k, average, **kwargs) - if task == "multilabel": + if task == ClassificationTask.MULTILABEL: assert isinstance(num_labels, int) return MultilabelHammingDistance(num_labels, threshold, average, **kwargs) - raise ValueError( - f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" - ) diff --git a/src/torchmetrics/classification/hinge.py b/src/torchmetrics/classification/hinge.py index 1d59b60621e..b68584e7f1b 100644 --- a/src/torchmetrics/classification/hinge.py +++ b/src/torchmetrics/classification/hinge.py @@ -29,6 +29,7 @@ _multiclass_hinge_loss_update, ) from torchmetrics.metric import Metric +from torchmetrics.utilities.enums import ClassificationTaskNoMultilabel class BinaryHingeLoss(Metric): @@ 
-252,12 +253,10 @@ def __new__( validate_args: bool = True, **kwargs: Any, ) -> Metric: + task = ClassificationTaskNoMultilabel.from_str(task) kwargs.update({"ignore_index": ignore_index, "validate_args": validate_args}) - if task == "binary": + if task == ClassificationTaskNoMultilabel.BINARY: return BinaryHingeLoss(squared, **kwargs) - if task == "multiclass": + if task == ClassificationTaskNoMultilabel.MULTICLASS: assert isinstance(num_classes, int) return MulticlassHingeLoss(num_classes, squared, multiclass_mode, **kwargs) - raise ValueError( - f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" - ) diff --git a/src/torchmetrics/classification/jaccard.py b/src/torchmetrics/classification/jaccard.py index 2b7e43f3c9c..ae1c37fd93c 100644 --- a/src/torchmetrics/classification/jaccard.py +++ b/src/torchmetrics/classification/jaccard.py @@ -23,6 +23,7 @@ _multilabel_jaccard_index_arg_validation, ) from torchmetrics.metric import Metric +from torchmetrics.utilities.enums import ClassificationTask class BinaryJaccardIndex(BinaryConfusionMatrix): @@ -295,15 +296,13 @@ def __new__( validate_args: bool = True, **kwargs: Any, ) -> Metric: + task = ClassificationTask.from_str(task) kwargs.update({"ignore_index": ignore_index, "validate_args": validate_args}) - if task == "binary": + if task == ClassificationTask.BINARY: return BinaryJaccardIndex(threshold, **kwargs) - if task == "multiclass": + if task == ClassificationTask.MULTICLASS: assert isinstance(num_classes, int) return MulticlassJaccardIndex(num_classes, average, **kwargs) - if task == "multilabel": + if task == ClassificationTask.MULTILABEL: assert isinstance(num_labels, int) return MultilabelJaccardIndex(num_labels, threshold, average, **kwargs) - raise ValueError( - f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" - ) diff --git a/src/torchmetrics/classification/matthews_corrcoef.py b/src/torchmetrics/classification/matthews_corrcoef.py index 39f4d40a169..cceaa4e4b43 100644 --- a/src/torchmetrics/classification/matthews_corrcoef.py +++ b/src/torchmetrics/classification/matthews_corrcoef.py @@ -19,6 +19,7 @@ from torchmetrics.classification import BinaryConfusionMatrix, MulticlassConfusionMatrix, MultilabelConfusionMatrix from torchmetrics.functional.classification.matthews_corrcoef import _matthews_corrcoef_reduce from torchmetrics.metric import Metric +from torchmetrics.utilities.enums import ClassificationTask class BinaryMatthewsCorrCoef(BinaryConfusionMatrix): @@ -237,15 +238,13 @@ def __new__( validate_args: bool = True, **kwargs: Any, ) -> Metric: + task = ClassificationTask.from_str(task) kwargs.update({"ignore_index": ignore_index, "validate_args": validate_args}) - if task == "binary": + if task == ClassificationTask.BINARY: return BinaryMatthewsCorrCoef(threshold, **kwargs) - if task == "multiclass": + if task == ClassificationTask.MULTICLASS: assert isinstance(num_classes, int) return MulticlassMatthewsCorrCoef(num_classes, **kwargs) - if task == "multilabel": + if task == ClassificationTask.MULTILABEL: assert isinstance(num_labels, int) return MultilabelMatthewsCorrCoef(num_labels, threshold, **kwargs) - raise ValueError( - f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" - ) diff --git a/src/torchmetrics/classification/precision_recall.py b/src/torchmetrics/classification/precision_recall.py index 4964e013f4a..300dc3364aa 100644 --- 
a/src/torchmetrics/classification/precision_recall.py +++ b/src/torchmetrics/classification/precision_recall.py @@ -19,6 +19,7 @@ from torchmetrics.classification.stat_scores import BinaryStatScores, MulticlassStatScores, MultilabelStatScores from torchmetrics.functional.classification.precision_recall import _precision_recall_reduce from torchmetrics.metric import Metric +from torchmetrics.utilities.enums import ClassificationTask class BinaryPrecision(BinaryStatScores): @@ -620,18 +621,16 @@ def __new__( kwargs.update( {"multidim_average": multidim_average, "ignore_index": ignore_index, "validate_args": validate_args} ) - if task == "binary": + task = ClassificationTask.from_str(task) + if task == ClassificationTask.BINARY: return BinaryPrecision(threshold, **kwargs) - if task == "multiclass": + if task == ClassificationTask.MULTICLASS: assert isinstance(num_classes, int) assert isinstance(top_k, int) return MulticlassPrecision(num_classes, top_k, average, **kwargs) - if task == "multilabel": + if task == ClassificationTask.MULTILABEL: assert isinstance(num_labels, int) return MultilabelPrecision(num_labels, threshold, average, **kwargs) - raise ValueError( - f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" - ) class Recall: @@ -672,19 +671,17 @@ def __new__( validate_args: bool = True, **kwargs: Any, ) -> Metric: + task = ClassificationTask.from_str(task) assert multidim_average is not None kwargs.update( {"multidim_average": multidim_average, "ignore_index": ignore_index, "validate_args": validate_args} ) - if task == "binary": + if task == ClassificationTask.BINARY: return BinaryRecall(threshold, **kwargs) - if task == "multiclass": + if task == ClassificationTask.MULTICLASS: assert isinstance(num_classes, int) assert isinstance(top_k, int) return MulticlassRecall(num_classes, top_k, average, **kwargs) - if task == "multilabel": + if task == ClassificationTask.MULTILABEL: assert isinstance(num_labels, int) return MultilabelRecall(num_labels, threshold, average, **kwargs) - raise ValueError( - f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" - ) diff --git a/src/torchmetrics/classification/precision_recall_curve.py b/src/torchmetrics/classification/precision_recall_curve.py index e26309413d1..dafdea2e6cf 100644 --- a/src/torchmetrics/classification/precision_recall_curve.py +++ b/src/torchmetrics/classification/precision_recall_curve.py @@ -37,6 +37,7 @@ ) from torchmetrics.metric import Metric from torchmetrics.utilities.data import dim_zero_cat +from torchmetrics.utilities.enums import ClassificationTask class BinaryPrecisionRecallCurve(Metric): @@ -466,15 +467,13 @@ def __new__( validate_args: bool = True, **kwargs: Any, ) -> Metric: + task = ClassificationTask.from_str(task) kwargs.update({"thresholds": thresholds, "ignore_index": ignore_index, "validate_args": validate_args}) - if task == "binary": + if task == ClassificationTask.BINARY: return BinaryPrecisionRecallCurve(**kwargs) - if task == "multiclass": + if task == ClassificationTask.MULTICLASS: assert isinstance(num_classes, int) return MulticlassPrecisionRecallCurve(num_classes, **kwargs) - if task == "multilabel": + if task == ClassificationTask.MULTILABEL: assert isinstance(num_labels, int) return MultilabelPrecisionRecallCurve(num_labels, **kwargs) - raise ValueError( - f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" - ) diff --git 
a/src/torchmetrics/classification/recall_at_fixed_precision.py b/src/torchmetrics/classification/recall_at_fixed_precision.py index 4b79171f1b8..188b71ce272 100644 --- a/src/torchmetrics/classification/recall_at_fixed_precision.py +++ b/src/torchmetrics/classification/recall_at_fixed_precision.py @@ -31,6 +31,7 @@ ) from torchmetrics.metric import Metric from torchmetrics.utilities.data import dim_zero_cat +from torchmetrics.utilities.enums import ClassificationTask class BinaryRecallAtFixedPrecision(BinaryPrecisionRecallCurve): @@ -323,18 +324,16 @@ def __new__( validate_args: bool = True, **kwargs: Any, ) -> Metric: - if task == "binary": + task = ClassificationTask.from_str(task) + if task == ClassificationTask.BINARY: return BinaryRecallAtFixedPrecision(min_precision, thresholds, ignore_index, validate_args, **kwargs) - if task == "multiclass": + if task == ClassificationTask.MULTICLASS: assert isinstance(num_classes, int) return MulticlassRecallAtFixedPrecision( num_classes, min_precision, thresholds, ignore_index, validate_args, **kwargs ) - if task == "multilabel": + if task == ClassificationTask.MULTILABEL: assert isinstance(num_labels, int) return MultilabelRecallAtFixedPrecision( num_labels, min_precision, thresholds, ignore_index, validate_args, **kwargs ) - raise ValueError( - f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" - ) diff --git a/src/torchmetrics/classification/roc.py b/src/torchmetrics/classification/roc.py index a8358b165a9..aa7d86c1cc0 100644 --- a/src/torchmetrics/classification/roc.py +++ b/src/torchmetrics/classification/roc.py @@ -28,6 +28,7 @@ ) from torchmetrics.metric import Metric from torchmetrics.utilities.data import dim_zero_cat +from torchmetrics.utilities.enums import ClassificationTask class BinaryROC(BinaryPrecisionRecallCurve): @@ -381,15 +382,13 @@ def __new__( validate_args: bool = True, **kwargs: Any, ) -> Metric: + task = ClassificationTask.from_str(task) kwargs.update({"thresholds": thresholds, "ignore_index": ignore_index, "validate_args": validate_args}) - if task == "binary": + if task == ClassificationTask.BINARY: return BinaryROC(**kwargs) - if task == "multiclass": + if task == ClassificationTask.MULTICLASS: assert isinstance(num_classes, int) return MulticlassROC(num_classes, **kwargs) - if task == "multilabel": + if task == ClassificationTask.MULTILABEL: assert isinstance(num_labels, int) return MultilabelROC(num_labels, **kwargs) - raise ValueError( - f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" - ) diff --git a/src/torchmetrics/classification/specificity.py b/src/torchmetrics/classification/specificity.py index 3468196812d..0965a99bffc 100644 --- a/src/torchmetrics/classification/specificity.py +++ b/src/torchmetrics/classification/specificity.py @@ -19,6 +19,7 @@ from torchmetrics.classification.stat_scores import BinaryStatScores, MulticlassStatScores, MultilabelStatScores from torchmetrics.functional.classification.specificity import _specificity_reduce from torchmetrics.metric import Metric +from torchmetrics.utilities.enums import ClassificationTask class BinarySpecificity(BinaryStatScores): @@ -320,19 +321,17 @@ def __new__( validate_args: bool = True, **kwargs: Any, ) -> Metric: + task = ClassificationTask.from_str(task) assert multidim_average is not None kwargs.update( {"multidim_average": multidim_average, "ignore_index": ignore_index, "validate_args": validate_args} ) - if task == "binary": + if task == 
ClassificationTask.BINARY: return BinarySpecificity(threshold, **kwargs) - if task == "multiclass": + if task == ClassificationTask.MULTICLASS: assert isinstance(num_classes, int) assert isinstance(top_k, int) return MulticlassSpecificity(num_classes, top_k, average, **kwargs) - if task == "multilabel": + if task == ClassificationTask.MULTILABEL: assert isinstance(num_labels, int) return MultilabelSpecificity(num_labels, threshold, average, **kwargs) - raise ValueError( - f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" - ) diff --git a/src/torchmetrics/classification/specificity_at_sensitivity.py b/src/torchmetrics/classification/specificity_at_sensitivity.py index b0fed56213c..05f4aa03c14 100644 --- a/src/torchmetrics/classification/specificity_at_sensitivity.py +++ b/src/torchmetrics/classification/specificity_at_sensitivity.py @@ -31,6 +31,7 @@ ) from torchmetrics.metric import Metric from torchmetrics.utilities.data import dim_zero_cat +from torchmetrics.utilities.enums import ClassificationTask class BinarySpecificityAtSensitivity(BinaryPrecisionRecallCurve): @@ -327,18 +328,16 @@ def __new__( # type: ignore validate_args: bool = True, **kwargs: Any, ) -> Metric: - if task == "binary": + task = ClassificationTask.from_str(task) + if task == ClassificationTask.BINARY: return BinarySpecificityAtSensitivity(min_sensitivity, thresholds, ignore_index, validate_args, **kwargs) - if task == "multiclass": + if task == ClassificationTask.MULTICLASS: assert isinstance(num_classes, int) return MulticlassSpecificityAtSensitivity( num_classes, min_sensitivity, thresholds, ignore_index, validate_args, **kwargs ) - if task == "multilabel": + if task == ClassificationTask.MULTILABEL: assert isinstance(num_labels, int) return MultilabelSpecificityAtSensitivity( num_labels, min_sensitivity, thresholds, ignore_index, validate_args, **kwargs ) - raise ValueError( - f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" - ) diff --git a/src/torchmetrics/classification/stat_scores.py b/src/torchmetrics/classification/stat_scores.py index 1c70ec32365..8fd61a10748 100644 --- a/src/torchmetrics/classification/stat_scores.py +++ b/src/torchmetrics/classification/stat_scores.py @@ -36,6 +36,7 @@ ) from torchmetrics.metric import Metric from torchmetrics.utilities.data import dim_zero_cat +from torchmetrics.utilities.enums import ClassificationTask class _AbstractStatScores(Metric): @@ -495,19 +496,17 @@ def __new__( validate_args: bool = True, **kwargs: Any, ) -> Metric: + task = ClassificationTask.from_str(task) assert multidim_average is not None kwargs.update( {"multidim_average": multidim_average, "ignore_index": ignore_index, "validate_args": validate_args} ) - if task == "binary": + if task == ClassificationTask.BINARY: return BinaryStatScores(threshold, **kwargs) - if task == "multiclass": + if task == ClassificationTask.MULTICLASS: assert isinstance(num_classes, int) assert isinstance(top_k, int) return MulticlassStatScores(num_classes, top_k, average, **kwargs) - if task == "multilabel": + if task == ClassificationTask.MULTILABEL: assert isinstance(num_labels, int) return MultilabelStatScores(num_labels, threshold, average, **kwargs) - raise ValueError( - f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" - ) diff --git a/src/torchmetrics/functional/classification/accuracy.py b/src/torchmetrics/functional/classification/accuracy.py index 
61b6c778054..109ebf35c41 100644 --- a/src/torchmetrics/functional/classification/accuracy.py +++ b/src/torchmetrics/functional/classification/accuracy.py @@ -32,6 +32,7 @@ _multilabel_stat_scores_update, ) from torchmetrics.utilities.compute import _safe_divide +from torchmetrics.utilities.enums import ClassificationTask def _accuracy_reduce( @@ -396,20 +397,19 @@ def accuracy( >>> accuracy(preds, target, task="multiclass", num_classes=3, top_k=2) tensor(0.6667) """ + task = ClassificationTask.from_str(task) assert multidim_average is not None - if task == "binary": + if task == ClassificationTask.BINARY: return binary_accuracy(preds, target, threshold, multidim_average, ignore_index, validate_args) - if task == "multiclass": + if task == ClassificationTask.MULTICLASS: assert isinstance(num_classes, int) assert isinstance(top_k, int) return multiclass_accuracy( preds, target, num_classes, average, top_k, multidim_average, ignore_index, validate_args ) - if task == "multilabel": + if task == ClassificationTask.MULTILABEL: assert isinstance(num_labels, int) return multilabel_accuracy( preds, target, num_labels, threshold, average, multidim_average, ignore_index, validate_args ) - raise ValueError( - f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" - ) + raise ValueError(f"Not handled value: {task}") # this is for compliant of mypy diff --git a/src/torchmetrics/functional/classification/auroc.py b/src/torchmetrics/functional/classification/auroc.py index 7565f6f99c3..f5548f8634b 100644 --- a/src/torchmetrics/functional/classification/auroc.py +++ b/src/torchmetrics/functional/classification/auroc.py @@ -38,6 +38,7 @@ ) from torchmetrics.utilities.compute import _auc_compute_without_check, _safe_divide from torchmetrics.utilities.data import _bincount +from torchmetrics.utilities.enums import ClassificationTask from torchmetrics.utilities.prints import rank_zero_warn @@ -450,14 +451,12 @@ def auroc( >>> auroc(preds, target, task='multiclass', num_classes=3) tensor(0.7778) """ - if task == "binary": + task = ClassificationTask.from_str(task) + if task == ClassificationTask.BINARY: return binary_auroc(preds, target, max_fpr, thresholds, ignore_index, validate_args) - if task == "multiclass": + if task == ClassificationTask.MULTICLASS: assert isinstance(num_classes, int) return multiclass_auroc(preds, target, num_classes, average, thresholds, ignore_index, validate_args) - if task == "multilabel": + if task == ClassificationTask.MULTILABEL: assert isinstance(num_labels, int) return multilabel_auroc(preds, target, num_labels, average, thresholds, ignore_index, validate_args) - raise ValueError( - f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" - ) diff --git a/src/torchmetrics/functional/classification/average_precision.py b/src/torchmetrics/functional/classification/average_precision.py index e481ea44ab5..a905f14bc49 100644 --- a/src/torchmetrics/functional/classification/average_precision.py +++ b/src/torchmetrics/functional/classification/average_precision.py @@ -36,6 +36,7 @@ ) from torchmetrics.utilities.compute import _safe_divide from torchmetrics.utilities.data import _bincount +from torchmetrics.utilities.enums import ClassificationTask from torchmetrics.utilities.prints import rank_zero_warn @@ -437,16 +438,14 @@ def average_precision( >>> average_precision(pred, target, task="multiclass", num_classes=5, average=None) tensor([1.0000, 1.0000, 0.2500, 0.2500, nan]) """ - if task == 
"binary": + task = ClassificationTask.from_str(task) + if task == ClassificationTask.BINARY: return binary_average_precision(preds, target, thresholds, ignore_index, validate_args) - if task == "multiclass": + if task == ClassificationTask.MULTICLASS: assert isinstance(num_classes, int) return multiclass_average_precision( preds, target, num_classes, average, thresholds, ignore_index, validate_args ) - if task == "multilabel": + if task == ClassificationTask.MULTILABEL: assert isinstance(num_labels, int) return multilabel_average_precision(preds, target, num_labels, average, thresholds, ignore_index, validate_args) - raise ValueError( - f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" - ) diff --git a/src/torchmetrics/functional/classification/calibration_error.py b/src/torchmetrics/functional/classification/calibration_error.py index 97428675ee3..ef32dd14f14 100644 --- a/src/torchmetrics/functional/classification/calibration_error.py +++ b/src/torchmetrics/functional/classification/calibration_error.py @@ -23,6 +23,7 @@ _multiclass_confusion_matrix_format, _multiclass_confusion_matrix_tensor_validation, ) +from torchmetrics.utilities.enums import ClassificationTaskNoMultilabel def _binning_bucketize( @@ -347,10 +348,11 @@ def calibration_error( :func:`binary_calibration_error` and :func:`multiclass_calibration_error` for the specific details of each argument influence and examples. """ + task = ClassificationTaskNoMultilabel.from_str(task) assert norm is not None - if task == "binary": + if task == ClassificationTaskNoMultilabel.BINARY: return binary_calibration_error(preds, target, n_bins, norm, ignore_index, validate_args) - if task == "multiclass": + if task == ClassificationTaskNoMultilabel.MULTICLASS: assert isinstance(num_classes, int) return multiclass_calibration_error(preds, target, num_classes, n_bins, norm, ignore_index, validate_args) raise ValueError(f"Expected argument `task` to either be `'binary'` or `'multiclass'` but got {task}") diff --git a/src/torchmetrics/functional/classification/cohen_kappa.py b/src/torchmetrics/functional/classification/cohen_kappa.py index 75bf6b53b4f..3a771858ae1 100644 --- a/src/torchmetrics/functional/classification/cohen_kappa.py +++ b/src/torchmetrics/functional/classification/cohen_kappa.py @@ -27,6 +27,7 @@ _multiclass_confusion_matrix_tensor_validation, _multiclass_confusion_matrix_update, ) +from torchmetrics.utilities.enums import ClassificationTaskNoMultilabel def _cohen_kappa_reduce(confmat: Tensor, weights: Optional[Literal["linear", "quadratic", "none"]] = None) -> Tensor: @@ -256,9 +257,10 @@ class labels. 
>>> cohen_kappa(preds, target, task="multiclass", num_classes=2) tensor(0.5000) """ - if task == "binary": + task = ClassificationTaskNoMultilabel.from_str(task) + if task == ClassificationTaskNoMultilabel.BINARY: return binary_cohen_kappa(preds, target, threshold, weights, ignore_index, validate_args) - if task == "multiclass": + if task == ClassificationTaskNoMultilabel.MULTICLASS: assert isinstance(num_classes, int) return multiclass_cohen_kappa(preds, target, num_classes, weights, ignore_index, validate_args) - raise ValueError(f"Expected argument `task` to either be `'binary'` or `'multiclass'` but got {task}") + raise ValueError(f"Not handled value: {task}") # this is for compliant of mypy diff --git a/src/torchmetrics/functional/classification/confusion_matrix.py b/src/torchmetrics/functional/classification/confusion_matrix.py index 562bbdc3042..2a04b14e723 100644 --- a/src/torchmetrics/functional/classification/confusion_matrix.py +++ b/src/torchmetrics/functional/classification/confusion_matrix.py @@ -19,6 +19,7 @@ from torchmetrics.utilities.checks import _check_same_shape from torchmetrics.utilities.data import _bincount +from torchmetrics.utilities.enums import ClassificationTask from torchmetrics.utilities.prints import rank_zero_warn @@ -630,14 +631,12 @@ def confusion_matrix( [[1, 0], [1, 0]], [[0, 1], [0, 1]]]) """ - if task == "binary": + task = ClassificationTask.from_str(task) + if task == ClassificationTask.BINARY: return binary_confusion_matrix(preds, target, threshold, normalize, ignore_index, validate_args) - if task == "multiclass": + if task == ClassificationTask.MULTICLASS: assert isinstance(num_classes, int) return multiclass_confusion_matrix(preds, target, num_classes, normalize, ignore_index, validate_args) - if task == "multilabel": + if task == ClassificationTask.MULTILABEL: assert isinstance(num_labels, int) return multilabel_confusion_matrix(preds, target, num_labels, threshold, normalize, ignore_index, validate_args) - raise ValueError( - f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" - ) diff --git a/src/torchmetrics/functional/classification/exact_match.py b/src/torchmetrics/functional/classification/exact_match.py index 89e043cbcd7..da41cb8c016 100644 --- a/src/torchmetrics/functional/classification/exact_match.py +++ b/src/torchmetrics/functional/classification/exact_match.py @@ -26,6 +26,7 @@ _multilabel_stat_scores_tensor_validation, ) from torchmetrics.utilities.compute import _safe_divide +from torchmetrics.utilities.enums import ClassificationTaskNoBinary def _exact_match_reduce( @@ -229,12 +230,13 @@ def exact_match( >>> exact_match(preds, target, task="multiclass", num_classes=3, multidim_average='samplewise') tensor([1., 0.]) """ - if task == "multiclass": + task = ClassificationTaskNoBinary.from_str(task) + if task == ClassificationTaskNoBinary.MULTICLASS: assert num_classes is not None return multiclass_exact_match(preds, target, num_classes, multidim_average, ignore_index, validate_args) - if task == "multilabel": + if task == ClassificationTaskNoBinary.MULTILABEL: assert num_labels is not None return multilabel_exact_match( preds, target, num_labels, threshold, multidim_average, ignore_index, validate_args ) - raise ValueError(f"Expected argument `task` to either be `'multiclass'` or `'multilabel'` but got {task}") + raise ValueError(f"Not handled value: {task}") # this is for compliant of mypy diff --git a/src/torchmetrics/functional/classification/f_beta.py 
b/src/torchmetrics/functional/classification/f_beta.py index 02200d6c0f0..5a649f70dab 100644 --- a/src/torchmetrics/functional/classification/f_beta.py +++ b/src/torchmetrics/functional/classification/f_beta.py @@ -32,6 +32,7 @@ _multilabel_stat_scores_update, ) from torchmetrics.utilities.compute import _safe_divide +from torchmetrics.utilities.enums import ClassificationTask def _fbeta_reduce( @@ -692,23 +693,21 @@ def fbeta_score( >>> fbeta_score(preds, target, task="multiclass", num_classes=3, beta=0.5) tensor(0.3333) """ + task = ClassificationTask.from_str(task) assert multidim_average is not None - if task == "binary": + if task == ClassificationTask.BINARY: return binary_fbeta_score(preds, target, beta, threshold, multidim_average, ignore_index, validate_args) - if task == "multiclass": + if task == ClassificationTask.MULTICLASS: assert isinstance(num_classes, int) assert isinstance(top_k, int) return multiclass_fbeta_score( preds, target, beta, num_classes, average, top_k, multidim_average, ignore_index, validate_args ) - if task == "multilabel": + if task == ClassificationTask.MULTILABEL: assert isinstance(num_labels, int) return multilabel_fbeta_score( preds, target, beta, num_labels, threshold, average, multidim_average, ignore_index, validate_args ) - raise ValueError( - f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" - ) def f1_score( @@ -741,20 +740,18 @@ def f1_score( >>> f1_score(preds, target, task="multiclass", num_classes=3) tensor(0.3333) """ + task = ClassificationTask.from_str(task) assert multidim_average is not None - if task == "binary": + if task == ClassificationTask.BINARY: return binary_f1_score(preds, target, threshold, multidim_average, ignore_index, validate_args) - if task == "multiclass": + if task == ClassificationTask.MULTICLASS: assert isinstance(num_classes, int) assert isinstance(top_k, int) return multiclass_f1_score( preds, target, num_classes, average, top_k, multidim_average, ignore_index, validate_args ) - if task == "multilabel": + if task == ClassificationTask.MULTILABEL: assert isinstance(num_labels, int) return multilabel_f1_score( preds, target, num_labels, threshold, average, multidim_average, ignore_index, validate_args ) - raise ValueError( - f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" - ) diff --git a/src/torchmetrics/functional/classification/hamming.py b/src/torchmetrics/functional/classification/hamming.py index 15b4733376c..e6c75379b5e 100644 --- a/src/torchmetrics/functional/classification/hamming.py +++ b/src/torchmetrics/functional/classification/hamming.py @@ -32,6 +32,7 @@ _multilabel_stat_scores_update, ) from torchmetrics.utilities.compute import _safe_divide +from torchmetrics.utilities.enums import ClassificationTask def _hamming_distance_reduce( @@ -398,20 +399,19 @@ def hamming_distance( >>> hamming_distance(preds, target, task="binary") tensor(0.2500) """ + task = ClassificationTask.from_str(task) assert multidim_average is not None - if task == "binary": + if task == ClassificationTask.BINARY: return binary_hamming_distance(preds, target, threshold, multidim_average, ignore_index, validate_args) - if task == "multiclass": + if task == ClassificationTask.MULTICLASS: assert isinstance(num_classes, int) assert isinstance(top_k, int) return multiclass_hamming_distance( preds, target, num_classes, average, top_k, multidim_average, ignore_index, validate_args ) - if task == "multilabel": + if task == 
ClassificationTask.MULTILABEL: assert isinstance(num_labels, int) return multilabel_hamming_distance( preds, target, num_labels, threshold, average, multidim_average, ignore_index, validate_args ) - raise ValueError( - f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" - ) + raise ValueError(f"Not handled value: {task}") # this is for compliant of mypy diff --git a/src/torchmetrics/functional/classification/hinge.py b/src/torchmetrics/functional/classification/hinge.py index 5621c1441fe..72cbdd85fb2 100644 --- a/src/torchmetrics/functional/classification/hinge.py +++ b/src/torchmetrics/functional/classification/hinge.py @@ -24,6 +24,7 @@ _multiclass_confusion_matrix_tensor_validation, ) from torchmetrics.utilities.data import to_onehot +from torchmetrics.utilities.enums import ClassificationTaskNoMultilabel def _hinge_loss_compute(measure: Tensor, total: Tensor) -> Tensor: @@ -276,9 +277,10 @@ def hinge_loss( >>> hinge_loss(preds, target, task="multiclass", num_classes=3, multiclass_mode="one-vs-all") tensor([1.3743, 1.1945, 1.2359]) """ - if task == "binary": + task = ClassificationTaskNoMultilabel.from_str(task) + if task == ClassificationTaskNoMultilabel.BINARY: return binary_hinge_loss(preds, target, squared, ignore_index, validate_args) - if task == "multiclass": + if task == ClassificationTaskNoMultilabel.MULTICLASS: assert isinstance(num_classes, int) return multiclass_hinge_loss(preds, target, num_classes, squared, multiclass_mode, ignore_index, validate_args) - raise ValueError(f"Expected argument `task` to either be `'binary'` or `'multilabel'` but got {task}") + raise ValueError(f"Not handled value: {task}") # this is for compliant of mypy diff --git a/src/torchmetrics/functional/classification/jaccard.py b/src/torchmetrics/functional/classification/jaccard.py index 5cc33273a5c..203dccb0b09 100644 --- a/src/torchmetrics/functional/classification/jaccard.py +++ b/src/torchmetrics/functional/classification/jaccard.py @@ -32,6 +32,7 @@ _multilabel_confusion_matrix_update, ) from torchmetrics.utilities.compute import _safe_divide +from torchmetrics.utilities.enums import ClassificationTask def _jaccard_index_reduce( @@ -321,14 +322,13 @@ def jaccard_index( >>> jaccard_index(pred, target, task="multiclass", num_classes=2) tensor(0.9660) """ - if task == "binary": + task = ClassificationTask.from_str(task) + if task == ClassificationTask.BINARY: return binary_jaccard_index(preds, target, threshold, ignore_index, validate_args) - if task == "multiclass": + if task == ClassificationTask.MULTICLASS: assert isinstance(num_classes, int) return multiclass_jaccard_index(preds, target, num_classes, average, ignore_index, validate_args) - if task == "multilabel": + if task == ClassificationTask.MULTILABEL: assert isinstance(num_labels, int) return multilabel_jaccard_index(preds, target, num_labels, threshold, average, ignore_index, validate_args) - raise ValueError( - f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" - ) + raise ValueError(f"Not handled value: {task}") # this is for compliant of mypy diff --git a/src/torchmetrics/functional/classification/matthews_corrcoef.py b/src/torchmetrics/functional/classification/matthews_corrcoef.py index c43b592dc1b..999746f11a4 100644 --- a/src/torchmetrics/functional/classification/matthews_corrcoef.py +++ b/src/torchmetrics/functional/classification/matthews_corrcoef.py @@ -31,6 +31,7 @@ _multilabel_confusion_matrix_tensor_validation, 
_multilabel_confusion_matrix_update, ) +from torchmetrics.utilities.enums import ClassificationTask def _matthews_corrcoef_reduce(confmat: Tensor) -> Tensor: @@ -213,7 +214,7 @@ def multilabel_matthews_corrcoef( def matthews_corrcoef( preds: Tensor, target: Tensor, - task: Literal["binary", "multiclass", "multilabel"] = None, + task: Literal["binary", "multiclass", "multilabel"], threshold: float = 0.5, num_classes: Optional[int] = None, num_labels: Optional[int] = None, @@ -235,14 +236,13 @@ def matthews_corrcoef( >>> matthews_corrcoef(preds, target, task="multiclass", num_classes=2) tensor(0.5774) """ - if task == "binary": + task = ClassificationTask.from_str(task) + if task == ClassificationTask.BINARY: return binary_matthews_corrcoef(preds, target, threshold, ignore_index, validate_args) - if task == "multiclass": + if task == ClassificationTask.MULTICLASS: assert isinstance(num_classes, int) return multiclass_matthews_corrcoef(preds, target, num_classes, ignore_index, validate_args) - if task == "multilabel": + if task == ClassificationTask.MULTILABEL: assert isinstance(num_labels, int) return multilabel_matthews_corrcoef(preds, target, num_labels, threshold, ignore_index, validate_args) - raise ValueError( - f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" - ) + raise ValueError(f"Not handled value: {task}") # this is for compliant of mypy diff --git a/src/torchmetrics/functional/classification/precision_recall.py b/src/torchmetrics/functional/classification/precision_recall.py index f04521b5b94..c131b2169b1 100644 --- a/src/torchmetrics/functional/classification/precision_recall.py +++ b/src/torchmetrics/functional/classification/precision_recall.py @@ -32,6 +32,7 @@ _multilabel_stat_scores_update, ) from torchmetrics.utilities.compute import _safe_divide +from torchmetrics.utilities.enums import ClassificationTask def _precision_recall_reduce( @@ -652,15 +653,15 @@ def precision( tensor(0.2500) """ assert multidim_average is not None - if task == "binary": + if task == ClassificationTask.BINARY: return binary_precision(preds, target, threshold, multidim_average, ignore_index, validate_args) - if task == "multiclass": + if task == ClassificationTask.MULTICLASS: assert isinstance(num_classes, int) assert isinstance(top_k, int) return multiclass_precision( preds, target, num_classes, average, top_k, multidim_average, ignore_index, validate_args ) - if task == "multilabel": + if task == ClassificationTask.MULTILABEL: assert isinstance(num_labels, int) return multilabel_precision( preds, target, num_labels, threshold, average, multidim_average, ignore_index, validate_args @@ -704,20 +705,19 @@ def recall( >>> recall(preds, target, task="multiclass", average='micro', num_classes=3) tensor(0.2500) """ + task = ClassificationTask.from_str(task) assert multidim_average is not None - if task == "binary": + if task == ClassificationTask.BINARY: return binary_recall(preds, target, threshold, multidim_average, ignore_index, validate_args) - if task == "multiclass": + if task == ClassificationTask.MULTICLASS: assert isinstance(num_classes, int) assert isinstance(top_k, int) return multiclass_recall( preds, target, num_classes, average, top_k, multidim_average, ignore_index, validate_args ) - if task == "multilabel": + if task == ClassificationTask.MULTILABEL: assert isinstance(num_labels, int) return multilabel_recall( preds, target, num_labels, threshold, average, multidim_average, ignore_index, validate_args ) - raise ValueError( - 
f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" - ) + raise ValueError(f"Not handled value: {task}") # this is for compliant of mypy diff --git a/src/torchmetrics/functional/classification/precision_recall_curve.py b/src/torchmetrics/functional/classification/precision_recall_curve.py index 99678d075b6..3391e7c1032 100644 --- a/src/torchmetrics/functional/classification/precision_recall_curve.py +++ b/src/torchmetrics/functional/classification/precision_recall_curve.py @@ -22,6 +22,7 @@ from torchmetrics.utilities.checks import _check_same_shape from torchmetrics.utilities.compute import _safe_divide from torchmetrics.utilities.data import _bincount +from torchmetrics.utilities.enums import ClassificationTask def _binary_clf_curve( @@ -815,14 +816,12 @@ def precision_recall_curve( >>> thresholds [tensor([0.7500]), tensor([0.7500]), tensor([0.0500, 0.7500]), tensor([0.0500, 0.7500]), tensor([0.0500])] """ - if task == "binary": + task = ClassificationTask.from_str(task) + if task == ClassificationTask.BINARY: return binary_precision_recall_curve(preds, target, thresholds, ignore_index, validate_args) - if task == "multiclass": + if task == ClassificationTask.MULTICLASS: assert isinstance(num_classes, int) return multiclass_precision_recall_curve(preds, target, num_classes, thresholds, ignore_index, validate_args) - if task == "multilabel": + if task == ClassificationTask.MULTILABEL: assert isinstance(num_labels, int) return multilabel_precision_recall_curve(preds, target, num_labels, thresholds, ignore_index, validate_args) - raise ValueError( - f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" - ) diff --git a/src/torchmetrics/functional/classification/recall_at_fixed_precision.py b/src/torchmetrics/functional/classification/recall_at_fixed_precision.py index d7d907fa755..8f79370c69a 100644 --- a/src/torchmetrics/functional/classification/recall_at_fixed_precision.py +++ b/src/torchmetrics/functional/classification/recall_at_fixed_precision.py @@ -34,6 +34,7 @@ _multilabel_precision_recall_curve_tensor_validation, _multilabel_precision_recall_curve_update, ) +from torchmetrics.utilities.enums import ClassificationTask def _recall_at_precision( @@ -384,18 +385,16 @@ def recall_at_fixed_precision( :func:`binary_recall_at_fixed_precision`, :func:`multiclass_recall_at_fixed_precision` and :func:`multilabel_recall_at_fixed_precision` for the specific details of each argument influence and examples. 
""" - if task == "binary": + task = ClassificationTask.from_str(task) + if task == ClassificationTask.BINARY: return binary_recall_at_fixed_precision(preds, target, min_precision, thresholds, ignore_index, validate_args) - if task == "multiclass": + if task == ClassificationTask.MULTICLASS: assert isinstance(num_classes, int) return multiclass_recall_at_fixed_precision( preds, target, num_classes, min_precision, thresholds, ignore_index, validate_args ) - if task == "multilabel": + if task == ClassificationTask.MULTILABEL: assert isinstance(num_labels, int) return multilabel_recall_at_fixed_precision( preds, target, num_labels, min_precision, thresholds, ignore_index, validate_args ) - raise ValueError( - f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" - ) diff --git a/src/torchmetrics/functional/classification/roc.py b/src/torchmetrics/functional/classification/roc.py index 7c0fe8d6a5e..ab77240fd4b 100644 --- a/src/torchmetrics/functional/classification/roc.py +++ b/src/torchmetrics/functional/classification/roc.py @@ -34,6 +34,7 @@ ) from torchmetrics.utilities import rank_zero_warn from torchmetrics.utilities.compute import _safe_divide +from torchmetrics.utilities.enums import ClassificationTask def _binary_roc_compute( @@ -483,14 +484,12 @@ def roc( tensor([1.0000, 0.7576, 0.3680, 0.3468, 0.0745]), tensor([1.0000, 0.1837, 0.1338, 0.1183, 0.1138])] """ - if task == "binary": + task = ClassificationTask.from_str(task) + if task == ClassificationTask.BINARY: return binary_roc(preds, target, thresholds, ignore_index, validate_args) - if task == "multiclass": + if task == ClassificationTask.MULTICLASS: assert isinstance(num_classes, int) return multiclass_roc(preds, target, num_classes, thresholds, ignore_index, validate_args) - if task == "multilabel": + if task == ClassificationTask.MULTILABEL: assert isinstance(num_labels, int) return multilabel_roc(preds, target, num_labels, thresholds, ignore_index, validate_args) - raise ValueError( - f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" - ) diff --git a/src/torchmetrics/functional/classification/specificity.py b/src/torchmetrics/functional/classification/specificity.py index 64e5f20b218..fa5f2b11567 100644 --- a/src/torchmetrics/functional/classification/specificity.py +++ b/src/torchmetrics/functional/classification/specificity.py @@ -32,6 +32,7 @@ _multilabel_stat_scores_update, ) from torchmetrics.utilities.compute import _safe_divide +from torchmetrics.utilities.enums import ClassificationTask def _specificity_reduce( @@ -369,20 +370,19 @@ def specificity( >>> specificity(preds, target, task="multiclass", average='micro', num_classes=3) tensor(0.6250) """ + task = ClassificationTask.from_str(task) assert multidim_average is not None - if task == "binary": + if task == ClassificationTask.BINARY: return binary_specificity(preds, target, threshold, multidim_average, ignore_index, validate_args) - if task == "multiclass": + if task == ClassificationTask.MULTICLASS: assert isinstance(num_classes, int) assert isinstance(top_k, int) return multiclass_specificity( preds, target, num_classes, average, top_k, multidim_average, ignore_index, validate_args ) - if task == "multilabel": + if task == ClassificationTask.MULTILABEL: assert isinstance(num_labels, int) return multilabel_specificity( preds, target, num_labels, threshold, average, multidim_average, ignore_index, validate_args ) - raise ValueError( - f"Expected argument `task` to 
either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" - ) + raise ValueError(f"Not handled value: {task}") # this is for compliant of mypy diff --git a/src/torchmetrics/functional/classification/specificity_at_sensitivity.py b/src/torchmetrics/functional/classification/specificity_at_sensitivity.py index e86a14da6ec..a97fcfa0a56 100644 --- a/src/torchmetrics/functional/classification/specificity_at_sensitivity.py +++ b/src/torchmetrics/functional/classification/specificity_at_sensitivity.py @@ -36,6 +36,7 @@ _multiclass_roc_compute, _multilabel_roc_compute, ) +from torchmetrics.utilities.enums import ClassificationTask def _convert_fpr_to_specificity(fpr: Tensor) -> Tensor: @@ -413,20 +414,19 @@ def specicity_at_sensitivity( :func:`binary_specificity_at_sensitivity`, :func:`multiclass_specicity_at_sensitivity` and :func:`multilabel_specifity_at_sensitvity` for the specific details of each argument influence and examples. """ - if task == "binary": + task = ClassificationTask.from_str(task) + if task == ClassificationTask.BINARY: return binary_specificity_at_sensitivity( # type: ignore preds, target, min_sensitivity, thresholds, ignore_index, validate_args ) - if task == "multiclass": + if task == ClassificationTask.MULTICLASS: assert isinstance(num_classes, int) return multiclass_specificity_at_sensitivity( # type: ignore preds, target, num_classes, min_sensitivity, thresholds, ignore_index, validate_args ) - if task == "multilabel": + if task == ClassificationTask.MULTILABEL: assert isinstance(num_labels, int) return multilabel_specificity_at_sensitivity( # type: ignore preds, target, num_labels, min_sensitivity, thresholds, ignore_index, validate_args ) - raise ValueError( - f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" - ) + raise ValueError(f"Not handled value: {task}") # this is for compliant of mypy diff --git a/src/torchmetrics/functional/classification/stat_scores.py b/src/torchmetrics/functional/classification/stat_scores.py index 1b5aa2282f8..3ca1b5b9447 100644 --- a/src/torchmetrics/functional/classification/stat_scores.py +++ b/src/torchmetrics/functional/classification/stat_scores.py @@ -19,7 +19,7 @@ from torchmetrics.utilities.checks import _check_same_shape, _input_format_classification from torchmetrics.utilities.data import _bincount, select_topk -from torchmetrics.utilities.enums import AverageMethod, DataType, MDMCAverageMethod +from torchmetrics.utilities.enums import AverageMethod, ClassificationTask, DataType, MDMCAverageMethod def _binary_stat_scores_arg_validation( @@ -1081,20 +1081,18 @@ def stat_scores( [1, 1, 1, 1, 2], [1, 0, 3, 0, 1]]) """ + task = ClassificationTask.from_str(task) assert multidim_average is not None - if task == "binary": + if task == ClassificationTask.BINARY: return binary_stat_scores(preds, target, threshold, multidim_average, ignore_index, validate_args) - if task == "multiclass": + if task == ClassificationTask.MULTICLASS: assert isinstance(num_classes, int) assert isinstance(top_k, int) return multiclass_stat_scores( preds, target, num_classes, average, top_k, multidim_average, ignore_index, validate_args ) - if task == "multilabel": + if task == ClassificationTask.MULTILABEL: assert isinstance(num_labels, int) return multilabel_stat_scores( preds, target, num_labels, threshold, average, multidim_average, ignore_index, validate_args ) - raise ValueError( - f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" - ) diff 
diff --git a/src/torchmetrics/functional/regression/kendall.py b/src/torchmetrics/functional/regression/kendall.py
index 287aa006f69..8308fadbb28 100644
--- a/src/torchmetrics/functional/regression/kendall.py
+++ b/src/torchmetrics/functional/regression/kendall.py
@@ -26,46 +26,26 @@
 class _MetricVariant(EnumStr):
     """Enumerate for metric variants."""

+    @staticmethod
+    def _name() -> str:
+        return "variant"
+
     A = "a"
     B = "b"
     C = "c"

-    @classmethod
-    def from_str(cls, value: Literal["a", "b", "c"]) -> "_MetricVariant":  # type: ignore[override]
-        """Raises:
-        ValueError:
-            If required metric variant is not among the supported options.
-        """
-        _allowed_variants = [im.lower() for im in _MetricVariant._member_names_]
-
-        enum_key = super().from_str(value)
-        if enum_key is not None and enum_key in _allowed_variants:
-            return enum_key  # type: ignore[return-value]  # use override
-        raise ValueError(f"Invalid metric variant. Expected one of {_allowed_variants}, but got {enum_key}.")
-

 class _TestAlternative(EnumStr):
     """Enumerate for test alternative options."""

+    @staticmethod
+    def _name() -> str:
+        return "alternative"
+
     TWO_SIDED = "two-sided"
     LESS = "less"
     GREATER = "greater"

-    @classmethod
-    def from_str(cls, value: Literal["two-sided", "less", "greater"]) -> "_TestAlternative":  # type: ignore[override]
-        """Load from string.
-
-        Raises:
-            ValueError:
-                If required test alternative is not among the supported options.
-        """
-        _allowed_alternatives = [im.lower().replace("_", "-") for im in _TestAlternative._member_names_]
-
-        enum_key = super().from_str(value.replace("-", "_"))
-        if enum_key is not None and enum_key in _allowed_alternatives:
-            return enum_key  # type: ignore[return-value]  # use override
-        raise ValueError(f"Invalid test alternative. Expected one of {_allowed_alternatives}, but got {enum_key}.")
-

 def _sort_on_first_sequence(x: Tensor, y: Tensor) -> Tuple[Tensor, Tensor]:
     """Sort sequences in an ascent order according to the sequence ``x``."""
diff --git a/src/torchmetrics/functional/text/infolm.py b/src/torchmetrics/functional/text/infolm.py
index b65d19fd139..25033d00100 100644
--- a/src/torchmetrics/functional/text/infolm.py
+++ b/src/torchmetrics/functional/text/infolm.py
@@ -54,6 +54,10 @@
 class _IMEnum(EnumStr):
     """A helper Enum class for storing the information measure."""

+    @staticmethod
+    def _name() -> str:
+        return "Information measure"
+
     KL_DIVERGENCE = "kl_divergence"
     ALPHA_DIVERGENCE = "alpha_divergence"
     BETA_DIVERGENCE = "beta_divergence"
@@ -64,19 +68,6 @@ class _IMEnum(EnumStr):
     L_INFINITY_DISTANCE = "l_infinity_distance"
     FISHER_RAO_DISTANCE = "fisher_rao_distance"

-    @classmethod
-    def from_str(cls, value: str) -> Optional["EnumStr"]:
-        """Raises:
-        ValueError:
-            If required information measure is not among the supported options.
-        """
-        _allowed_im = [im.lower() for im in _IMEnum._member_names_]
-
-        enum_key = super().from_str(value)
-        if enum_key is not None and enum_key in _allowed_im:
-            return enum_key
-        raise ValueError(f"Invalid information measure. Expected one of {_allowed_im}, but got {enum_key}.")
-

 class _InformationMeasure:
     """A wrapper class used for the calculation the result of information measure between the discrete reference
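The per-enum ``from_str`` overrides deleted above all duplicated the same lookup-and-raise logic. After this change each enum only declares a ``_name()`` hook, and the shared ``EnumStr.from_str`` added in the next file builds the error message from it. A rough sketch of the resulting pattern, using a made-up ``_Demo`` enum rather than anything from the codebase:

    from torchmetrics.utilities.enums import EnumStr

    class _Demo(EnumStr):  # hypothetical enum, for illustration only
        @staticmethod
        def _name() -> str:
            return "demo option"  # the only piece a subclass customises

        FIRST = "first"
        SECOND = "second"

    _Demo.from_str("first")  # resolves to _Demo.FIRST via the shared lookup
    _Demo.from_str("third")  # ValueError: Invalid demo option: expected one of ['first', 'second'], but got third.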
diff --git a/src/torchmetrics/utilities/enums.py b/src/torchmetrics/utilities/enums.py
index 75fc6e61ca3..f2327854165 100644
--- a/src/torchmetrics/utilities/enums.py
+++ b/src/torchmetrics/utilities/enums.py
@@ -11,38 +11,39 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from enum import Enum
-from typing import Optional, Union
+from typing import Optional
+from lightning_utilities.core.enums import StrEnum as StrEnum


-class EnumStr(str, Enum):
-    """Type of any enumerator with allowed comparison to string invariant to cases.
-
-    Example:
-        >>> class MyEnum(EnumStr):
-        ...     ABC = 'abc'
-        >>> MyEnum.from_str('Abc')
-        <MyEnum.ABC: 'abc'>
-        >>> {MyEnum.ABC: 123}
-        {<MyEnum.ABC: 'abc'>: 123}
-    """
+class EnumStr(StrEnum):
+    @staticmethod
+    def _name() -> str:
+        return "Task"

     @classmethod
-    def from_str(cls, value: str) -> Optional["EnumStr"]:
-        statuses = [status for status in dir(cls) if not status.startswith("_")]
-        for st in statuses:
-            if st.lower() == value.lower():
-                return getattr(cls, st)
-        return None
+    def from_str(cls, value: str) -> "EnumStr":
+        """Load from string.

-    def __eq__(self, other: Union[str, "EnumStr", None]) -> bool:  # type: ignore
-        other = other.value if isinstance(other, Enum) else str(other)
-        return self.value.lower() == other.lower()
+        Raises:
+            ValueError:
+                If required value is not among the supported options.

-    def __hash__(self) -> int:
-        # re-enable hashtable so it can be used as a dict key or in a set
-        # example: set(EnumStr)
-        return hash(self.name)
+        >>> class MyEnum(EnumStr):
+        ...     a = "aaa"
+        ...     b = "bbb"
+        >>> MyEnum.from_str("a")
+        <MyEnum.a: 'aaa'>
+        >>> MyEnum.from_str("c")
+        Traceback (most recent call last):
+          ...
+        ValueError: Invalid Task: expected one of ['a', 'b'], but got c.
+        """
+        enum_key = super().from_str(value.replace("-", "_"))
+        if enum_key is not None:
+            return enum_key
+        _allowed_im = [m.lower() for m in cls._member_names_]
+        raise ValueError(f"Invalid {cls._name()}: expected one of {_allowed_im}, but got {value}.")


 class DataType(EnumStr):
@@ -52,6 +53,10 @@ class DataType(EnumStr):
     True
     """

+    @staticmethod
+    def _name() -> str:
+        return "Data type"
+
     BINARY = "binary"
     MULTILABEL = "multi-label"
     MULTICLASS = "multi-class"
@@ -69,6 +74,10 @@ class AverageMethod(EnumStr):
     True
     """

+    @staticmethod
+    def _name() -> str:
+        return "Average method"
+
     MICRO = "micro"
     MACRO = "macro"
     WEIGHTED = "weighted"
@@ -79,5 +88,55 @@ class MDMCAverageMethod(EnumStr):
     """Enum to represent multi-dim multi-class average method."""

+    @staticmethod
+    def _name() -> str:
+        return "MDMC Average method"
+
     GLOBAL = "global"
     SAMPLEWISE = "samplewise"
+
+
+class ClassificationTask(EnumStr):
+    """Enum to represent the different tasks in classification metrics.
+
+    >>> "binary" in list(ClassificationTask)
+    True
+    """
+
+    @staticmethod
+    def _name() -> str:
+        return "Classification"
+
+    BINARY = "binary"
+    MULTICLASS = "multiclass"
+    MULTILABEL = "multilabel"
+
+
+class ClassificationTaskNoBinary(EnumStr):
+    """Enum to represent the different tasks in classification metrics.
+
+    >>> "binary" in list(ClassificationTaskNoBinary)
+    False
+    """
+
+    @staticmethod
+    def _name() -> str:
+        return "Classification"
+
+    MULTILABEL = "multilabel"
+    MULTICLASS = "multiclass"
+
+
+class ClassificationTaskNoMultilabel(EnumStr):
+    """Enum to represent the different tasks in classification metrics.
+
+    >>> "multilabel" in list(ClassificationTaskNoMultilabel)
+    False
+    """
+
+    @staticmethod
+    def _name() -> str:
+        return "Classification"
+
+    BINARY = "binary"
+    MULTICLASS = "multiclass"
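Taken together, the new enums give every task-dispatching metric the same validation path and the same error format, while string arguments keep working as before since they are converted by ``from_str`` on entry. A short usage sketch (the error text follows the ``from_str`` format defined above; the invalid value is only an example):

    import torch

    from torchmetrics.functional import roc
    from torchmetrics.utilities.enums import ClassificationTask

    # Plain strings are still accepted and normalised to the enum internally.
    preds = torch.tensor([0.1, 0.8, 0.4, 0.9])
    target = torch.tensor([0, 1, 0, 1])
    fpr, tpr, thresholds = roc(preds, target, task="binary")

    assert ClassificationTask.from_str("binary") == ClassificationTask.BINARY

    # Unknown tasks now fail inside from_str with the shared message, e.g.
    # ValueError: Invalid Classification: expected one of
    # ['binary', 'multiclass', 'multilabel'], but got segmentation.
    ClassificationTask.from_str("segmentation")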