diff --git a/tests/unittests/wrappers/test_tracker.py b/tests/unittests/wrappers/test_tracker.py index a2981d2d126..084d39eeee3 100644 --- a/tests/unittests/wrappers/test_tracker.py +++ b/tests/unittests/wrappers/test_tracker.py @@ -22,7 +22,7 @@ MulticlassPrecision, MulticlassRecall, ) -from torchmetrics.wrappers import MetricTracker +from torchmetrics.wrappers import MetricTracker, MultioutputWrapper from unittests.helpers import seed_all seed_all(42) @@ -95,6 +95,11 @@ def test_raises_error_if_increment_not_called(method, method_input): (torch.randn(50), torch.randn(50)), [False, False], ), + ( + MultioutputWrapper(MeanSquaredError(), num_outputs=2), + (torch.randn(50, 2), torch.randn(50, 2)), + [False, False], + ), ], ) def test_tracker(base_metric, metric_input, maximize): @@ -113,6 +118,9 @@ def test_tracker(base_metric, metric_input, maximize): if isinstance(val, dict): for v in val.values(): assert v != 0.0 + elif isinstance(val, list): + for v in val: + assert v != 0.0 else: assert val != 0.0 assert tracker.n_steps == i + 1 @@ -123,6 +131,9 @@ def test_tracker(base_metric, metric_input, maximize): if isinstance(all_computed_val, dict): for v in all_computed_val.values(): assert v.numel() == 5 + elif isinstance(all_computed_val, list): + for v in all_computed_val: + assert v.numel() == 5 else: assert all_computed_val.numel() == 5 @@ -132,6 +143,10 @@ def test_tracker(base_metric, metric_input, maximize): for v, i in zip(val.values(), idx.values()): assert v != 0.0 assert i in list(range(5)) + elif isinstance(val, list): + for v, i in zip(val, idx): + assert v != 0.0 + assert i in list(range(5)) else: assert val != 0.0 assert idx in list(range(5))