diff --git a/requirements/segmentation_test.txt b/requirements/segmentation_test.txt index d1f12efc7a1..fff5018b029 100644 --- a/requirements/segmentation_test.txt +++ b/requirements/segmentation_test.txt @@ -2,4 +2,4 @@ # in case you want to preserve/enforce restrictions on the latest compatible version, add "strict" as an in-line comment scipy >1.0.0, <1.15.0 -monai ==1.3.2 +monai ==1.4.0 diff --git a/tests/unittests/_helpers/testers.py b/tests/unittests/_helpers/testers.py index c5a69077f3c..1b46d6f237f 100644 --- a/tests/unittests/_helpers/testers.py +++ b/tests/unittests/_helpers/testers.py @@ -37,7 +37,7 @@ def _assert_allclose(tm_result: Any, ref_result: Any, atol: float = 1e-8, key: O ref_result.detach().cpu().numpy() if isinstance(ref_result, Tensor) else ref_result, atol=atol, equal_nan=True, - ) + ), f"tm_result: {tm_result}, ref_result: {ref_result}" # multi output compare elif isinstance(tm_result, Sequence): for pl_res, ref_res in zip(tm_result, ref_result): @@ -50,7 +50,7 @@ def _assert_allclose(tm_result: Any, ref_result: Any, atol: float = 1e-8, key: O ref_result.detach().cpu().numpy() if isinstance(ref_result, Tensor) else ref_result, atol=atol, equal_nan=True, - ) + ), f"tm_result: {tm_result}, ref_result: {ref_result}" else: raise ValueError("Unknown format for comparison") diff --git a/tests/unittests/segmentation/test_generalized_dice_score.py b/tests/unittests/segmentation/test_generalized_dice_score.py index 3f8acec842a..f5ec310f96d 100644 --- a/tests/unittests/segmentation/test_generalized_dice_score.py +++ b/tests/unittests/segmentation/test_generalized_dice_score.py @@ -51,10 +51,10 @@ def _reference_generalized_dice( if input_format == "index": preds = torch.nn.functional.one_hot(preds, num_classes=NUM_CLASSES).movedim(-1, 1) target = torch.nn.functional.one_hot(target, num_classes=NUM_CLASSES).movedim(-1, 1) - val = compute_generalized_dice(preds, target, include_background=include_background) + val = compute_generalized_dice(preds, target, include_background=include_background, sum_over_classes=True) if reduce: val = val.mean() - return val + return val.squeeze() @pytest.mark.parametrize(