Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Jan 9, 2022
1 parent 247dd79 commit ab6a5cd
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 5 deletions.
17 changes: 13 additions & 4 deletions tests/classification/test_f_beta.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@
from tests.helpers import seed_all
from tests.helpers.testers import NUM_BATCHES, NUM_CLASSES, THRESHOLD, MetricTester
from torchmetrics import F1Score, FBeta, Metric
from torchmetrics.functional import f1_score as f1_score_pl, fbeta
from torchmetrics.functional import f1_score as f1_score_pl
from torchmetrics.functional import fbeta
from torchmetrics.utilities.checks import _input_format_classification
from torchmetrics.utilities.enums import AverageMethod

Expand Down Expand Up @@ -178,7 +179,9 @@ def test_no_support(metric_class, metric_fn):
assert result_cl == result_fn == 0


@pytest.mark.parametrize("metric_class, metric_fn", [(partial(FBeta, beta=2.0), partial(fbeta, beta=2.0)), (F1Score, f1_score_pl)])
@pytest.mark.parametrize(
"metric_class, metric_fn", [(partial(FBeta, beta=2.0), partial(fbeta, beta=2.0)), (F1Score, f1_score_pl)]
)
@pytest.mark.parametrize(
"ignore_index, expected", [(None, torch.tensor([1.0, np.nan])), (0, torch.tensor([np.nan, np.nan]))]
)
Expand All @@ -202,7 +205,10 @@ def test_class_not_present(metric_class, metric_fn, ignore_index, expected):

@pytest.mark.parametrize(
"metric_class, metric_fn, sk_fn",
[(partial(FBeta, beta=2.0), partial(fbeta, beta=2.0), partial(fbeta_score, beta=2.0)), (F1Score, f1_score_pl, f1_score)],
[
(partial(FBeta, beta=2.0), partial(fbeta, beta=2.0), partial(fbeta_score, beta=2.0)),
(F1Score, f1_score_pl, f1_score),
],
)
@pytest.mark.parametrize("average", ["micro", "macro", None, "weighted", "samples"])
@pytest.mark.parametrize("ignore_index", [None, 0])
Expand Down Expand Up @@ -430,7 +436,10 @@ def test_top_k(
@pytest.mark.parametrize("average", ["micro", "macro", "weighted"])
@pytest.mark.parametrize(
"metric_class, metric_functional, sk_fn",
[(partial(FBeta, beta=2.0), partial(fbeta, beta=2.0), partial(fbeta_score, beta=2.0)), (F1Score, f1_score_pl, f1_score)],
[
(partial(FBeta, beta=2.0), partial(fbeta, beta=2.0), partial(fbeta_score, beta=2.0)),
(F1Score, f1_score_pl, f1_score),
],
)
def test_same_input(metric_class, metric_functional, sk_fn, average, ignore_index):
preds = _input_miss_class.preds
Expand Down
2 changes: 1 addition & 1 deletion torchmetrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
from torchmetrics.classification import ( # noqa: E402, F401
AUC,
AUROC,
F1Score,
F1,
ROC,
Accuracy,
Expand All @@ -38,6 +37,7 @@
CalibrationError,
CohenKappa,
ConfusionMatrix,
F1Score,
FBeta,
HammingDistance,
Hinge,
Expand Down
1 change: 1 addition & 0 deletions torchmetrics/functional/classification/f_beta.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,6 +350,7 @@ def f1_score(
"""
return fbeta(preds, target, 1.0, average, mdmc_average, ignore_index, num_classes, threshold, top_k, multiclass)


def f1(
preds: Tensor,
target: Tensor,
Expand Down

0 comments on commit ab6a5cd

Please sign in to comment.