From 17fa7ec8a1814541cb576c9ca892ec6fefea21e0 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 9 Jan 2022 02:29:38 +0000 Subject: [PATCH] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/classification/test_f_beta.py | 7 ++++--- torchmetrics/__init__.py | 2 +- torchmetrics/functional/classification/f_beta.py | 1 + 3 files changed, 6 insertions(+), 4 deletions(-) diff --git a/tests/classification/test_f_beta.py b/tests/classification/test_f_beta.py index 02b4241de21..582ce6847e9 100644 --- a/tests/classification/test_f_beta.py +++ b/tests/classification/test_f_beta.py @@ -33,7 +33,8 @@ from tests.helpers import seed_all from tests.helpers.testers import NUM_BATCHES, NUM_CLASSES, THRESHOLD, MetricTester from torchmetrics import F1Score, FBeta, Metric -from torchmetrics.functional import f1_score as f1_score_pl, fbeta +from torchmetrics.functional import f1_score as f1_score_pl +from torchmetrics.functional import fbeta from torchmetrics.utilities.checks import _input_format_classification from torchmetrics.utilities.enums import AverageMethod @@ -183,7 +184,7 @@ def test_no_support(metric_class, metric_fn): [ (partial(FBeta, beta=2.0), partial(fbeta, beta=2.0)), (F1Score, f1_score_pl), - ] + ], ) @pytest.mark.parametrize( "ignore_index, expected", [(None, torch.tensor([1.0, np.nan])), (0, torch.tensor([np.nan, np.nan]))] @@ -441,7 +442,7 @@ def test_top_k( "metric_class, metric_functional, sk_fn", [ (partial(FBeta, beta=2.0), partial(fbeta, beta=2.0), partial(fbeta_score, beta=2.0)), - (F1Score, f1_score_pl, f1_score) + (F1Score, f1_score_pl, f1_score), ], ) def test_same_input(metric_class, metric_functional, sk_fn, average, ignore_index): diff --git a/torchmetrics/__init__.py b/torchmetrics/__init__.py index 6dbd9bddbe0..183ba3cdcba 100644 --- a/torchmetrics/__init__.py +++ b/torchmetrics/__init__.py @@ -27,7 +27,6 @@ from torchmetrics.classification import ( # noqa: E402, F401 AUC, AUROC, - F1Score, F1, ROC, Accuracy, @@ -38,6 +37,7 @@ CalibrationError, CohenKappa, ConfusionMatrix, + F1Score, FBeta, HammingDistance, Hinge, diff --git a/torchmetrics/functional/classification/f_beta.py b/torchmetrics/functional/classification/f_beta.py index 97ccd2255d3..2640be0c34b 100644 --- a/torchmetrics/functional/classification/f_beta.py +++ b/torchmetrics/functional/classification/f_beta.py @@ -350,6 +350,7 @@ def f1_score( """ return fbeta(preds, target, 1.0, average, mdmc_average, ignore_index, num_classes, threshold, top_k, multiclass) + def f1( preds: Tensor, target: Tensor,