Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Jul 18, 2021
1 parent a2a2b94 commit 6d7b761
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions tests/classification/test_average_precision.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,15 @@

import numpy as np
import pytest
from sklearn.metrics import average_precision_score as sk_average_precision_score
import torch
from sklearn.metrics import average_precision_score as sk_average_precision_score
from torch import tensor

from tests.classification.inputs import _input_binary_prob
from tests.classification.inputs import _input_multiclass_prob as _input_mcls_prob
from tests.classification.inputs import _input_multidim_multiclass_prob as _input_mdmc_prob
from tests.helpers import seed_all
from tests.helpers.testers import NUM_CLASSES, MetricTester, BATCH_SIZE
from tests.helpers.testers import BATCH_SIZE, NUM_CLASSES, MetricTester
from torchmetrics.classification.average_precision import AveragePrecision
from torchmetrics.functional import average_precision

Expand Down Expand Up @@ -123,10 +123,10 @@ def test_average_precision(scores, target, expected_score):

@pytest.mark.parametrize(
"preds, target, sk_metric, num_classes", [
(torch.randn(BATCH_SIZE, NUM_CLASSES),
torch.randint(0, 2, (BATCH_SIZE, NUM_CLASSES)),
_sk_avg_prec_multiclass_prob,
NUM_CLASSES),
(
torch.randn(BATCH_SIZE, NUM_CLASSES), torch.randint(0, 2, (BATCH_SIZE, NUM_CLASSES)),
_sk_avg_prec_multiclass_prob, NUM_CLASSES
),
]
)
def test_average_precision_multilabel(preds, target, sk_metric, num_classes):
Expand Down

0 comments on commit 6d7b761

Please sign in to comment.