Skip to content

Commit

Permalink
Update test_tracker (Lightning-AI#1409)
Browse files Browse the repository at this point in the history
* Add test configuration for test_tracker with a `MetricTracker` of `MultioutputWrapper`
* Change the test to support the case where the values returned by `MetricTracker` are lists
  • Loading branch information
ValerianRey committed Dec 31, 2022
1 parent e919d57 commit 4f18b4e
Showing 1 changed file with 16 additions and 1 deletion.
17 changes: 16 additions & 1 deletion tests/unittests/wrappers/test_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
MulticlassPrecision,
MulticlassRecall,
)
from torchmetrics.wrappers import MetricTracker
from torchmetrics.wrappers import MetricTracker, MultioutputWrapper
from unittests.helpers import seed_all

seed_all(42)
Expand Down Expand Up @@ -95,6 +95,11 @@ def test_raises_error_if_increment_not_called(method, method_input):
(torch.randn(50), torch.randn(50)),
[False, False],
),
(
MultioutputWrapper(MeanSquaredError(), num_outputs=2),
(torch.randn(50, 2), torch.randn(50, 2)),
[False, False],
),
],
)
def test_tracker(base_metric, metric_input, maximize):
Expand All @@ -113,6 +118,9 @@ def test_tracker(base_metric, metric_input, maximize):
if isinstance(val, dict):
for v in val.values():
assert v != 0.0
elif isinstance(val, list):
for v in val:
assert v != 0.0
else:
assert val != 0.0
assert tracker.n_steps == i + 1
Expand All @@ -123,6 +131,9 @@ def test_tracker(base_metric, metric_input, maximize):
if isinstance(all_computed_val, dict):
for v in all_computed_val.values():
assert v.numel() == 5
elif isinstance(all_computed_val, list):
for v in all_computed_val:
assert v.numel() == 5
else:
assert all_computed_val.numel() == 5

Expand All @@ -132,6 +143,10 @@ def test_tracker(base_metric, metric_input, maximize):
for v, i in zip(val.values(), idx.values()):
assert v != 0.0
assert i in list(range(5))
elif isinstance(val, list):
for v, i in zip(val, idx):
assert v != 0.0
assert i in list(range(5))
else:
assert val != 0.0
assert idx in list(range(5))
Expand Down

0 comments on commit 4f18b4e

Please sign in to comment.