Skip to content

Commit

Permalink
Merge branch 'master' into ruff/first_line_split
Browse files Browse the repository at this point in the history
  • Loading branch information
SkafteNicki authored Feb 24, 2023
2 parents 511b5d2 + 82d5271 commit 34579e1
Show file tree
Hide file tree
Showing 60 changed files with 202 additions and 148 deletions.
10 changes: 9 additions & 1 deletion tests/unittests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,15 @@

import torch

from unittests.helpers.testers import BATCH_SIZE, NUM_BATCHES, NUM_PROCESSES, DummyMetric, MetricTester # noqa: F401
from unittests.conftest import ( # noqa: F401
BATCH_SIZE,
EXTRA_DIM,
NUM_BATCHES,
NUM_CLASSES,
NUM_PROCESSES,
THRESHOLD,
setup_ddp,
)

_PATH_TESTS = os.path.dirname(__file__)
_PATH_ROOT = os.path.dirname(_PATH_TESTS)
Expand Down
3 changes: 2 additions & 1 deletion tests/unittests/audio/test_pit.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,9 @@
scale_invariant_signal_distortion_ratio,
signal_noise_ratio,
)
from unittests import BATCH_SIZE, NUM_BATCHES
from unittests.helpers import seed_all
from unittests.helpers.testers import BATCH_SIZE, NUM_BATCHES, MetricTester
from unittests.helpers.testers import MetricTester

seed_all(42)

Expand Down
3 changes: 2 additions & 1 deletion tests/unittests/audio/test_si_sdr.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,9 @@

from torchmetrics.audio import ScaleInvariantSignalDistortionRatio
from torchmetrics.functional import scale_invariant_signal_distortion_ratio
from unittests import BATCH_SIZE, NUM_BATCHES
from unittests.helpers import seed_all
from unittests.helpers.testers import BATCH_SIZE, NUM_BATCHES, MetricTester
from unittests.helpers.testers import MetricTester

seed_all(42)

Expand Down
3 changes: 2 additions & 1 deletion tests/unittests/audio/test_si_snr.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,9 @@

from torchmetrics.audio import ScaleInvariantSignalNoiseRatio
from torchmetrics.functional import scale_invariant_signal_noise_ratio
from unittests import BATCH_SIZE, NUM_BATCHES
from unittests.helpers import seed_all
from unittests.helpers.testers import BATCH_SIZE, NUM_BATCHES, MetricTester
from unittests.helpers.testers import MetricTester

seed_all(42)

Expand Down
3 changes: 2 additions & 1 deletion tests/unittests/audio/test_snr.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,9 @@

from torchmetrics.audio import SignalNoiseRatio
from torchmetrics.functional import signal_noise_ratio
from unittests import NUM_BATCHES
from unittests.helpers import seed_all
from unittests.helpers.testers import NUM_BATCHES, MetricTester
from unittests.helpers.testers import MetricTester

seed_all(42)

Expand Down
3 changes: 2 additions & 1 deletion tests/unittests/bases/test_aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
import torch

from torchmetrics.aggregation import CatMetric, MaxMetric, MeanMetric, MinMetric, SumMetric
from unittests.helpers.testers import BATCH_SIZE, NUM_BATCHES, MetricTester
from unittests import BATCH_SIZE, NUM_BATCHES
from unittests.helpers.testers import MetricTester


def compare_mean(values, weights):
Expand Down
54 changes: 21 additions & 33 deletions tests/unittests/bases/test_ddp.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import os
import sys
from copy import deepcopy
from functools import partial

import pytest
import torch
Expand All @@ -22,14 +23,14 @@
from torchmetrics import Metric
from torchmetrics.utilities.distributed import gather_all_tensors
from torchmetrics.utilities.exceptions import TorchMetricsUserError
from unittests import NUM_PROCESSES
from unittests.helpers import seed_all
from unittests.helpers.testers import DummyListMetric, DummyMetric, DummyMetricSum, setup_ddp
from unittests.helpers.testers import DummyListMetric, DummyMetric, DummyMetricSum

seed_all(42)


def _test_ddp_sum(rank, worldsize):
setup_ddp(rank, worldsize)
def _test_ddp_sum(rank: int, worldsize: int = NUM_PROCESSES) -> None:
dummy = DummyMetric()
dummy._reductions = {"foo": torch.sum}
dummy.foo = tensor(1)
Expand All @@ -38,8 +39,7 @@ def _test_ddp_sum(rank, worldsize):
assert dummy.foo == worldsize


def _test_ddp_cat(rank, worldsize):
setup_ddp(rank, worldsize)
def _test_ddp_cat(rank: int, worldsize: int = NUM_PROCESSES) -> None:
dummy = DummyMetric()
dummy._reductions = {"foo": torch.cat}
dummy.foo = [tensor([1])]
Expand All @@ -48,8 +48,7 @@ def _test_ddp_cat(rank, worldsize):
assert torch.all(torch.eq(dummy.foo, tensor([1, 1])))


def _test_ddp_sum_cat(rank, worldsize):
setup_ddp(rank, worldsize)
def _test_ddp_sum_cat(rank: int, worldsize: int = NUM_PROCESSES) -> None:
dummy = DummyMetric()
dummy._reductions = {"foo": torch.cat, "bar": torch.sum}
dummy.foo = [tensor([1])]
Expand All @@ -60,29 +59,24 @@ def _test_ddp_sum_cat(rank, worldsize):
assert dummy.bar == worldsize


def _test_ddp_gather_uneven_tensors(rank, worldsize):
setup_ddp(rank, worldsize)
def _test_ddp_gather_uneven_tensors(rank: int, worldsize: int = NUM_PROCESSES) -> None:
tensor = torch.ones(rank)
result = gather_all_tensors(tensor)
assert len(result) == worldsize
for idx in range(worldsize):
assert len(result[idx]) == idx
assert (result[idx] == torch.ones_like(result[idx])).all()


def _test_ddp_gather_uneven_tensors_multidim(rank, worldsize):
setup_ddp(rank, worldsize)
def _test_ddp_gather_uneven_tensors_multidim(rank: int, worldsize: int = NUM_PROCESSES) -> None:
tensor = torch.ones(rank + 1, 2 - rank)
result = gather_all_tensors(tensor)
assert len(result) == worldsize
for idx in range(worldsize):
val = result[idx]
assert val.shape == (idx + 1, 2 - idx)
assert (val == torch.ones_like(val)).all()


def _test_ddp_compositional_tensor(rank, worldsize):
setup_ddp(rank, worldsize)
def _test_ddp_compositional_tensor(rank: int, worldsize: int = NUM_PROCESSES) -> None:
dummy = DummyMetricSum()
dummy._reductions = {"x": torch.sum}
dummy = dummy.clone() + dummy.clone()
Expand All @@ -104,12 +98,10 @@ def _test_ddp_compositional_tensor(rank, worldsize):
],
)
def test_ddp(process):
torch.multiprocessing.spawn(process, args=(2,), nprocs=2)
pytest.pool.map(process, range(NUM_PROCESSES))


def _test_non_contiguous_tensors(rank, worldsize):
setup_ddp(rank, worldsize)

def _test_non_contiguous_tensors(rank):
class DummyCatMetric(Metric):
full_state_update = True

Expand All @@ -131,12 +123,10 @@ def compute(self):
@pytest.mark.skipif(sys.platform == "win32", reason="DDP not available on windows")
def test_non_contiguous_tensors():
"""Test that gather_all operation works for non contiguous tensors."""
torch.multiprocessing.spawn(_test_non_contiguous_tensors, args=(2,), nprocs=2)

pytest.pool.map(_test_non_contiguous_tensors, range(NUM_PROCESSES))

def _test_state_dict_is_synced(rank, worldsize, tmpdir):
setup_ddp(rank, worldsize)

def _test_state_dict_is_synced(rank, tmpdir):
class DummyCatMetric(Metric):
full_state_update = True

Expand Down Expand Up @@ -241,11 +231,10 @@ def reload_state_dict(state_dict, expected_x, expected_c):
@pytest.mark.skipif(sys.platform == "win32", reason="DDP not available on windows")
def test_state_dict_is_synced(tmpdir):
"""Tests that metrics are synced while creating the state dict but restored after to continue accumulation."""
torch.multiprocessing.spawn(_test_state_dict_is_synced, args=(2, tmpdir), nprocs=2)
pytest.pool.map(partial(_test_state_dict_is_synced, tmpdir=tmpdir), range(NUM_PROCESSES))


def _test_sync_on_compute_tensor_state(rank, worldsize, sync_on_compute):
setup_ddp(rank, worldsize)
def _test_sync_on_compute_tensor_state(rank, sync_on_compute):
dummy = DummyMetricSum(sync_on_compute=sync_on_compute)
dummy.update(tensor(rank + 1))
val = dummy.compute()
Expand All @@ -255,13 +244,13 @@ def _test_sync_on_compute_tensor_state(rank, worldsize, sync_on_compute):
assert val == rank + 1


def _test_sync_on_compute_list_state(rank, worldsize, sync_on_compute):
setup_ddp(rank, worldsize)
def _test_sync_on_compute_list_state(rank, sync_on_compute):
dummy = DummyListMetric(sync_on_compute=sync_on_compute)
dummy.update(tensor(rank + 1))
val = dummy.compute()
if sync_on_compute:
assert torch.allclose(val, tensor([1, 2]))
assert val.sum() == 3
assert torch.allclose(val, tensor([1, 2])) or torch.allclose(val, tensor([2, 1]))
else:
assert val == [tensor(rank + 1)]

Expand All @@ -271,11 +260,10 @@ def _test_sync_on_compute_list_state(rank, worldsize, sync_on_compute):
@pytest.mark.parametrize("test_func", [_test_sync_on_compute_list_state, _test_sync_on_compute_tensor_state])
def test_sync_on_compute(sync_on_compute, test_func):
"""Test that syncronization of states can be enabled and disabled for compute."""
torch.multiprocessing.spawn(test_func, args=(2, sync_on_compute), nprocs=2)
pytest.pool.map(partial(test_func, sync_on_compute=sync_on_compute), range(NUM_PROCESSES))


def _test_sync_with_empty_lists(rank, worldsize):
setup_ddp(rank, worldsize)
def _test_sync_with_empty_lists(rank):
dummy = DummyListMetric()
val = dummy.compute()
assert val == []
Expand All @@ -284,4 +272,4 @@ def _test_sync_with_empty_lists(rank, worldsize):
@pytest.mark.skipif(sys.platform == "win32", reason="DDP not available on windows")
def test_sync_with_empty_lists():
"""Test that syncronization of states can be enabled and disabled for compute."""
torch.multiprocessing.spawn(_test_sync_with_empty_lists, args=(2,), nprocs=2)
pytest.pool.map(_test_sync_with_empty_lists, range(NUM_PROCESSES))
2 changes: 1 addition & 1 deletion tests/unittests/classification/inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@
import torch
from torch import Tensor

from unittests import BATCH_SIZE, EXTRA_DIM, NUM_BATCHES, NUM_CLASSES
from unittests.helpers import seed_all
from unittests.helpers.testers import BATCH_SIZE, EXTRA_DIM, NUM_BATCHES, NUM_CLASSES

seed_all(1)

Expand Down
3 changes: 2 additions & 1 deletion tests/unittests/classification/test_accuracy.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,10 @@

from torchmetrics.classification.accuracy import BinaryAccuracy, MulticlassAccuracy, MultilabelAccuracy
from torchmetrics.functional.classification.accuracy import binary_accuracy, multiclass_accuracy, multilabel_accuracy
from unittests import NUM_CLASSES, THRESHOLD
from unittests.classification.inputs import _binary_cases, _multiclass_cases, _multilabel_cases
from unittests.helpers import seed_all
from unittests.helpers.testers import NUM_CLASSES, THRESHOLD, MetricTester, inject_ignore_index, remove_ignore_index
from unittests.helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index

seed_all(42)

Expand Down
3 changes: 2 additions & 1 deletion tests/unittests/classification/test_auroc.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,10 @@
from torchmetrics.classification.auroc import BinaryAUROC, MulticlassAUROC, MultilabelAUROC
from torchmetrics.functional.classification.auroc import binary_auroc, multiclass_auroc, multilabel_auroc
from torchmetrics.functional.classification.roc import binary_roc
from unittests import NUM_CLASSES
from unittests.classification.inputs import _binary_cases, _multiclass_cases, _multilabel_cases
from unittests.helpers import seed_all
from unittests.helpers.testers import NUM_CLASSES, MetricTester, inject_ignore_index, remove_ignore_index
from unittests.helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index

seed_all(42)

Expand Down
3 changes: 2 additions & 1 deletion tests/unittests/classification/test_average_precision.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,10 @@
multilabel_average_precision,
)
from torchmetrics.functional.classification.precision_recall_curve import binary_precision_recall_curve
from unittests import NUM_CLASSES
from unittests.classification.inputs import _binary_cases, _multiclass_cases, _multilabel_cases
from unittests.helpers import seed_all
from unittests.helpers.testers import NUM_CLASSES, MetricTester, inject_ignore_index, remove_ignore_index
from unittests.helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index

seed_all(42)

Expand Down
3 changes: 2 additions & 1 deletion tests/unittests/classification/test_calibration_error.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,10 @@
multiclass_calibration_error,
)
from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_9
from unittests import NUM_CLASSES
from unittests.classification.inputs import _binary_cases, _multiclass_cases
from unittests.helpers import seed_all
from unittests.helpers.testers import NUM_CLASSES, MetricTester, inject_ignore_index, remove_ignore_index
from unittests.helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index

seed_all(42)

Expand Down
3 changes: 2 additions & 1 deletion tests/unittests/classification/test_cohen_kappa.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,10 @@

from torchmetrics.classification.cohen_kappa import BinaryCohenKappa, MulticlassCohenKappa
from torchmetrics.functional.classification.cohen_kappa import binary_cohen_kappa, multiclass_cohen_kappa
from unittests import NUM_CLASSES, THRESHOLD
from unittests.classification.inputs import _binary_cases, _multiclass_cases
from unittests.helpers import seed_all
from unittests.helpers.testers import NUM_CLASSES, THRESHOLD, MetricTester, inject_ignore_index, remove_ignore_index
from unittests.helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index

seed_all(42)

Expand Down
3 changes: 2 additions & 1 deletion tests/unittests/classification/test_confusion_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,10 @@
multiclass_confusion_matrix,
multilabel_confusion_matrix,
)
from unittests import NUM_CLASSES, THRESHOLD
from unittests.classification.inputs import _binary_cases, _multiclass_cases, _multilabel_cases
from unittests.helpers import seed_all
from unittests.helpers.testers import NUM_CLASSES, THRESHOLD, MetricTester, inject_ignore_index, remove_ignore_index
from unittests.helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index

seed_all(42)

Expand Down
3 changes: 2 additions & 1 deletion tests/unittests/classification/test_exact_match.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,10 @@

from torchmetrics.classification.exact_match import MulticlassExactMatch, MultilabelExactMatch
from torchmetrics.functional.classification.exact_match import multiclass_exact_match, multilabel_exact_match
from unittests import NUM_CLASSES, THRESHOLD
from unittests.classification.inputs import _multiclass_cases, _multilabel_cases
from unittests.helpers import seed_all
from unittests.helpers.testers import NUM_CLASSES, THRESHOLD, MetricTester, inject_ignore_index
from unittests.helpers.testers import MetricTester, inject_ignore_index

seed_all(42)

Expand Down
3 changes: 2 additions & 1 deletion tests/unittests/classification/test_f_beta.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,10 @@
multilabel_f1_score,
multilabel_fbeta_score,
)
from unittests import NUM_CLASSES, THRESHOLD
from unittests.classification.inputs import _binary_cases, _multiclass_cases, _multilabel_cases
from unittests.helpers import seed_all
from unittests.helpers.testers import NUM_CLASSES, THRESHOLD, MetricTester, inject_ignore_index, remove_ignore_index
from unittests.helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index

seed_all(42)

Expand Down
3 changes: 2 additions & 1 deletion tests/unittests/classification/test_hamming_distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,10 @@
multiclass_hamming_distance,
multilabel_hamming_distance,
)
from unittests import NUM_CLASSES, THRESHOLD
from unittests.classification.inputs import _binary_cases, _multiclass_cases, _multilabel_cases
from unittests.helpers import seed_all
from unittests.helpers.testers import NUM_CLASSES, THRESHOLD, MetricTester, inject_ignore_index, remove_ignore_index
from unittests.helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index

seed_all(42)

Expand Down
3 changes: 2 additions & 1 deletion tests/unittests/classification/test_hinge.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,9 @@

from torchmetrics.classification.hinge import BinaryHingeLoss, MulticlassHingeLoss
from torchmetrics.functional.classification.hinge import binary_hinge_loss, multiclass_hinge_loss
from unittests import NUM_CLASSES
from unittests.classification.inputs import _binary_cases, _multiclass_cases
from unittests.helpers.testers import NUM_CLASSES, MetricTester, inject_ignore_index, remove_ignore_index
from unittests.helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index

torch.manual_seed(42)

Expand Down
3 changes: 2 additions & 1 deletion tests/unittests/classification/test_jaccard.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,9 @@
multiclass_jaccard_index,
multilabel_jaccard_index,
)
from unittests import NUM_CLASSES, THRESHOLD
from unittests.classification.inputs import _binary_cases, _multiclass_cases, _multilabel_cases
from unittests.helpers.testers import NUM_CLASSES, THRESHOLD, MetricTester, inject_ignore_index, remove_ignore_index
from unittests.helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index


def _sklearn_jaccard_index_binary(preds, target, ignore_index=None):
Expand Down
3 changes: 2 additions & 1 deletion tests/unittests/classification/test_matthews_corrcoef.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,10 @@
multiclass_matthews_corrcoef,
multilabel_matthews_corrcoef,
)
from unittests import NUM_CLASSES, THRESHOLD
from unittests.classification.inputs import _binary_cases, _multiclass_cases, _multilabel_cases
from unittests.helpers import seed_all
from unittests.helpers.testers import NUM_CLASSES, THRESHOLD, MetricTester, inject_ignore_index, remove_ignore_index
from unittests.helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index

seed_all(42)

Expand Down
3 changes: 2 additions & 1 deletion tests/unittests/classification/test_precision_recall.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,10 @@
multilabel_precision,
multilabel_recall,
)
from unittests import NUM_CLASSES, THRESHOLD
from unittests.classification.inputs import _binary_cases, _multiclass_cases, _multilabel_cases
from unittests.helpers import seed_all
from unittests.helpers.testers import NUM_CLASSES, THRESHOLD, MetricTester, inject_ignore_index, remove_ignore_index
from unittests.helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index

seed_all(42)

Expand Down
Loading

0 comments on commit 34579e1

Please sign in to comment.