From 2a32b9c814b0fde72dbeed91e4f282bf81aeb9f9 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Sun, 9 Jan 2022 01:52:49 +0000
Subject: [PATCH] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 tests/classification/test_f_beta.py | 17 +++++++++++++----
 torchmetrics/__init__.py            |  2 +-
 2 files changed, 14 insertions(+), 5 deletions(-)

diff --git a/tests/classification/test_f_beta.py b/tests/classification/test_f_beta.py
index 9b7bc913795..0839c81c532 100644
--- a/tests/classification/test_f_beta.py
+++ b/tests/classification/test_f_beta.py
@@ -33,7 +33,8 @@
 from tests.helpers import seed_all
 from tests.helpers.testers import NUM_BATCHES, NUM_CLASSES, THRESHOLD, MetricTester
 from torchmetrics import F1Score, FBeta, Metric
-from torchmetrics.functional import f1_score as f1_score_pl, fbeta
+from torchmetrics.functional import f1_score as f1_score_pl
+from torchmetrics.functional import fbeta
 from torchmetrics.utilities.checks import _input_format_classification
 from torchmetrics.utilities.enums import AverageMethod
 
@@ -178,7 +179,9 @@ def test_no_support(metric_class, metric_fn):
     assert result_cl == result_fn == 0
 
 
-@pytest.mark.parametrize("metric_class, metric_fn", [(partial(FBeta, beta=2.0), partial(fbeta, beta=2.0)), (F1Score, f1_score_pl)])
+@pytest.mark.parametrize(
+    "metric_class, metric_fn", [(partial(FBeta, beta=2.0), partial(fbeta, beta=2.0)), (F1Score, f1_score_pl)]
+)
 @pytest.mark.parametrize(
     "ignore_index, expected", [(None, torch.tensor([1.0, np.nan])), (0, torch.tensor([np.nan, np.nan]))]
 )
@@ -202,7 +205,10 @@ def test_class_not_present(metric_class, metric_fn, ignore_index, expected):
 
 @pytest.mark.parametrize(
     "metric_class, metric_fn, sk_fn",
-    [(partial(FBeta, beta=2.0), partial(fbeta, beta=2.0), partial(fbeta_score, beta=2.0)), (F1Score, f1_score_pl, f1_score)],
+    [
+        (partial(FBeta, beta=2.0), partial(fbeta, beta=2.0), partial(fbeta_score, beta=2.0)),
+        (F1Score, f1_score_pl, f1_score),
+    ],
 )
 @pytest.mark.parametrize("average", ["micro", "macro", None, "weighted", "samples"])
 @pytest.mark.parametrize("ignore_index", [None, 0])
@@ -430,7 +436,10 @@ def test_top_k(
 @pytest.mark.parametrize("average", ["micro", "macro", "weighted"])
 @pytest.mark.parametrize(
     "metric_class, metric_functional, sk_fn",
-    [(partial(FBeta, beta=2.0), partial(fbeta, beta=2.0), partial(fbeta_score, beta=2.0)), (F1Score, f1_score_pl, f1_score)],
+    [
+        (partial(FBeta, beta=2.0), partial(fbeta, beta=2.0), partial(fbeta_score, beta=2.0)),
+        (F1Score, f1_score_pl, f1_score),
+    ],
 )
 def test_same_input(metric_class, metric_functional, sk_fn, average, ignore_index):
     preds = _input_miss_class.preds
diff --git a/torchmetrics/__init__.py b/torchmetrics/__init__.py
index 49dd5e76bca..cfd3ba6b62e 100644
--- a/torchmetrics/__init__.py
+++ b/torchmetrics/__init__.py
@@ -17,7 +17,6 @@
 from torchmetrics.classification import (  # noqa: E402, F401
     AUC,
     AUROC,
-    F1Score,
     ROC,
     Accuracy,
     AveragePrecision,
@@ -27,6 +26,7 @@
     CalibrationError,
     CohenKappa,
     ConfusionMatrix,
+    F1Score,
     FBeta,
     HammingDistance,
     Hinge,
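
Not part of the patch itself: a minimal sketch exercising the imports this diff touches, assuming a torchmetrics of this era (~v0.7) where the F-score APIs carry these names (FBeta was later renamed FBetaScore). The tensors and num_classes value are illustrative only.

import torch

from torchmetrics import F1Score, FBeta
from torchmetrics.functional import f1_score as f1_score_pl
from torchmetrics.functional import fbeta

preds = torch.tensor([0, 2, 1, 1])
target = torch.tensor([0, 1, 1, 1])

# Module API: stateful metric objects that accumulate over batches.
f2 = FBeta(num_classes=3, beta=2.0)
f1 = F1Score(num_classes=3)
print(f2(preds, target), f1(preds, target))

# Functional API: one-shot computation on full tensors, imported on
# separate lines as in the reformatted hunk above.
print(fbeta(preds, target, num_classes=3, beta=2.0))
print(f1_score_pl(preds, target, num_classes=3))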