Skip to content

Commit

Permalink
modify object F1Score
Browse files Browse the repository at this point in the history
  • Loading branch information
cuent committed Jan 9, 2022
1 parent 23785c6 commit de79133
Show file tree
Hide file tree
Showing 5 changed files with 16 additions and 16 deletions.
4 changes: 2 additions & 2 deletions docs/source/references/modules.rst
Original file line number Diff line number Diff line change
Expand Up @@ -278,10 +278,10 @@ ConfusionMatrix
.. autoclass:: torchmetrics.ConfusionMatrix
:noindex:

F1
F1Score
~~~~~~~

.. autoclass:: torchmetrics.F1
.. autoclass:: torchmetrics.F1Score
:noindex:

FBeta
Expand Down
16 changes: 8 additions & 8 deletions tests/classification/test_f_beta.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from tests.classification.inputs import _input_multilabel_prob as _input_mlb_prob
from tests.helpers import seed_all
from tests.helpers.testers import NUM_BATCHES, NUM_CLASSES, THRESHOLD, MetricTester
from torchmetrics import F1, FBeta, Metric
from torchmetrics import F1Score, FBeta, Metric
from torchmetrics.functional import f1_score as f1_score_pl, fbeta
from torchmetrics.utilities.checks import _input_format_classification
from torchmetrics.utilities.enums import AverageMethod
Expand Down Expand Up @@ -93,7 +93,7 @@ def _sk_fbeta_f1_multidim_multiclass(
"metric_class, metric_fn",
[
(partial(FBeta, beta=2.0), partial(fbeta, beta=2.0)),
(F1, f1_score_pl),
(F1Score, f1_score_pl),
],
)
@pytest.mark.parametrize(
Expand Down Expand Up @@ -129,7 +129,7 @@ def test_wrong_params(metric_class, metric_fn, average, mdmc_average, num_classe
"metric_class, metric_fn",
[
(partial(FBeta, beta=2.0), partial(fbeta, beta=2.0)),
(F1, f1_score_pl),
(F1Score, f1_score_pl),
],
)
def test_zero_division(metric_class, metric_fn):
Expand All @@ -151,7 +151,7 @@ def test_zero_division(metric_class, metric_fn):
"metric_class, metric_fn",
[
(partial(FBeta, beta=2.0), partial(fbeta, beta=2.0)),
(F1, f1_score_pl),
(F1Score, f1_score_pl),
],
)
def test_no_support(metric_class, metric_fn):
Expand All @@ -178,7 +178,7 @@ def test_no_support(metric_class, metric_fn):
assert result_cl == result_fn == 0


@pytest.mark.parametrize("metric_class, metric_fn", [(partial(FBeta, beta=2.0), partial(fbeta, beta=2.0)), (F1, f1_score_pl)])
@pytest.mark.parametrize("metric_class, metric_fn", [(partial(FBeta, beta=2.0), partial(fbeta, beta=2.0)), (F1Score, f1_score_pl)])
@pytest.mark.parametrize(
"ignore_index, expected", [(None, torch.tensor([1.0, np.nan])), (0, torch.tensor([np.nan, np.nan]))]
)
Expand All @@ -202,7 +202,7 @@ def test_class_not_present(metric_class, metric_fn, ignore_index, expected):

@pytest.mark.parametrize(
"metric_class, metric_fn, sk_fn",
[(partial(FBeta, beta=2.0), partial(fbeta, beta=2.0), partial(fbeta_score, beta=2.0)), (F1, f1_score_pl, f1_score)],
[(partial(FBeta, beta=2.0), partial(fbeta, beta=2.0), partial(fbeta_score, beta=2.0)), (F1Score, f1_score_pl, f1_score)],
)
@pytest.mark.parametrize("average", ["micro", "macro", None, "weighted", "samples"])
@pytest.mark.parametrize("ignore_index", [None, 0])
Expand Down Expand Up @@ -388,7 +388,7 @@ def test_fbeta_f1_differentiability(
"metric_class, metric_fn",
[
(partial(FBeta, beta=2.0), partial(fbeta, beta=2.0)),
(F1, fbeta),
(F1Score, fbeta),
],
)
@pytest.mark.parametrize(
Expand Down Expand Up @@ -430,7 +430,7 @@ def test_top_k(
@pytest.mark.parametrize("average", ["micro", "macro", "weighted"])
@pytest.mark.parametrize(
"metric_class, metric_functional, sk_fn",
[(partial(FBeta, beta=2.0), partial(fbeta, beta=2.0), partial(fbeta_score, beta=2.0)), (F1, f1_score_pl, f1_score)],
[(partial(FBeta, beta=2.0), partial(fbeta, beta=2.0), partial(fbeta_score, beta=2.0)), (F1Score, f1_score_pl, f1_score)],
)
def test_same_input(metric_class, metric_functional, sk_fn, average, ignore_index):
preds = _input_miss_class.preds
Expand Down
4 changes: 2 additions & 2 deletions torchmetrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from torchmetrics.classification import ( # noqa: E402, F401
AUC,
AUROC,
F1,
F1Score,
ROC,
Accuracy,
AveragePrecision,
Expand Down Expand Up @@ -103,7 +103,7 @@
"CosineSimilarity",
"TweedieDevianceScore",
"ExplainedVariance",
"F1",
"F1Score",
"FBeta",
"HammingDistance",
"Hinge",
Expand Down
2 changes: 1 addition & 1 deletion torchmetrics/classification/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from torchmetrics.classification.calibration_error import CalibrationError # noqa: F401
from torchmetrics.classification.cohen_kappa import CohenKappa # noqa: F401
from torchmetrics.classification.confusion_matrix import ConfusionMatrix # noqa: F401
from torchmetrics.classification.f_beta import F1, FBeta # noqa: F401
from torchmetrics.classification.f_beta import F1Score, FBeta # noqa: F401
from torchmetrics.classification.hamming_distance import HammingDistance # noqa: F401
from torchmetrics.classification.hinge import Hinge # noqa: F401
from torchmetrics.classification.iou import IoU # noqa: F401
Expand Down
6 changes: 3 additions & 3 deletions torchmetrics/classification/f_beta.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ def compute(self) -> Tensor:
return _fbeta_compute(tp, fp, tn, fn, self.beta, self.ignore_index, self.average, self.mdmc_reduce)


class F1(FBeta):
class F1Score(FBeta):
"""Computes F1 metric. F1 metrics correspond to a harmonic mean of the precision and recall scores.
Works with binary, multiclass, and multilabel data. Accepts logits or probabilities from a model
Expand Down Expand Up @@ -261,10 +261,10 @@ class F1(FBeta):
Example:
>>> from torchmetrics import F1
>>> from torchmetrics import F1Score
>>> target = torch.tensor([0, 1, 2, 0, 1, 2])
>>> preds = torch.tensor([0, 2, 1, 0, 0, 1])
>>> f1 = F1(num_classes=3)
>>> f1 = F1Score(num_classes=3)
>>> f1(preds, target)
tensor(0.3333)
"""
Expand Down

0 comments on commit de79133

Please sign in to comment.