Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Jan 9, 2022
1 parent 3a7f051 commit 17fa7ec
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 4 deletions.
7 changes: 4 additions & 3 deletions tests/classification/test_f_beta.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@
from tests.helpers import seed_all
from tests.helpers.testers import NUM_BATCHES, NUM_CLASSES, THRESHOLD, MetricTester
from torchmetrics import F1Score, FBeta, Metric
from torchmetrics.functional import f1_score as f1_score_pl, fbeta
from torchmetrics.functional import f1_score as f1_score_pl
from torchmetrics.functional import fbeta
from torchmetrics.utilities.checks import _input_format_classification
from torchmetrics.utilities.enums import AverageMethod

Expand Down Expand Up @@ -183,7 +184,7 @@ def test_no_support(metric_class, metric_fn):
[
(partial(FBeta, beta=2.0), partial(fbeta, beta=2.0)),
(F1Score, f1_score_pl),
]
],
)
@pytest.mark.parametrize(
"ignore_index, expected", [(None, torch.tensor([1.0, np.nan])), (0, torch.tensor([np.nan, np.nan]))]
Expand Down Expand Up @@ -441,7 +442,7 @@ def test_top_k(
"metric_class, metric_functional, sk_fn",
[
(partial(FBeta, beta=2.0), partial(fbeta, beta=2.0), partial(fbeta_score, beta=2.0)),
(F1Score, f1_score_pl, f1_score)
(F1Score, f1_score_pl, f1_score),
],
)
def test_same_input(metric_class, metric_functional, sk_fn, average, ignore_index):
Expand Down
2 changes: 1 addition & 1 deletion torchmetrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
from torchmetrics.classification import ( # noqa: E402, F401
AUC,
AUROC,
F1Score,
F1,
ROC,
Accuracy,
Expand All @@ -38,6 +37,7 @@
CalibrationError,
CohenKappa,
ConfusionMatrix,
F1Score,
FBeta,
HammingDistance,
Hinge,
Expand Down
1 change: 1 addition & 0 deletions torchmetrics/functional/classification/f_beta.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,6 +350,7 @@ def f1_score(
"""
return fbeta(preds, target, 1.0, average, mdmc_average, ignore_index, num_classes, threshold, top_k, multiclass)


def f1(
preds: Tensor,
target: Tensor,
Expand Down

0 comments on commit 17fa7ec

Please sign in to comment.