From 8cd9fdfcb8845581b994e892dd1a9f181dd37dac Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Mon, 11 Sep 2023 23:04:34 +0200 Subject: [PATCH] add new tests for F.adjust_saturation --- test/test_transforms_v2_refactored.py | 42 +++++++++++++++++++++++++++ 1 file changed, 42 insertions(+) diff --git a/test/test_transforms_v2_refactored.py b/test/test_transforms_v2_refactored.py index cb856d4d798..ddadad8ca54 100644 --- a/test/test_transforms_v2_refactored.py +++ b/test/test_transforms_v2_refactored.py @@ -4494,3 +4494,45 @@ def test_correctness_image(self, hue_factor): mae = (actual.float() - expected.float()).abs().mean() assert mae < 2 + + +class TestAdjustSaturation: + @pytest.mark.parametrize("dtype", [torch.uint8, torch.float32]) + @pytest.mark.parametrize("device", cpu_and_cuda()) + def test_kernel_image(self, dtype, device): + check_kernel(F.adjust_saturation_image, make_image(dtype=dtype, device=device), saturation_factor=0.5) + + def test_kernel_video(self): + check_kernel(F.adjust_saturation_video, make_video(), saturation_factor=0.5) + + @pytest.mark.parametrize("make_input", [make_image_tensor, make_image, make_image_pil, make_video]) + def test_functional(self, make_input): + check_functional(F.adjust_saturation, make_input(), saturation_factor=0.5) + + @pytest.mark.parametrize( + ("kernel", "input_type"), + [ + (F.adjust_saturation_image, torch.Tensor), + (F._adjust_saturation_image_pil, PIL.Image.Image), + (F.adjust_saturation_image, tv_tensors.Image), + (F.adjust_saturation_video, tv_tensors.Video), + ], + ) + def test_functional_signature(self, kernel, input_type): + check_functional_kernel_signature_match(F.adjust_saturation, kernel=kernel, input_type=input_type) + + def test_functional_error(self): + with pytest.raises(TypeError, match="permitted channel values are 1 or 3"): + F.adjust_saturation(make_image(color_space="RGBA"), saturation_factor=0.5) + + with pytest.raises(ValueError, match="is not non-negative"): + F.adjust_saturation(make_image(), saturation_factor=-1) + + @pytest.mark.parametrize("saturation_factor", [0.1, 0.5, 1.0]) + def test_correctness_image(self, saturation_factor): + image = make_image(dtype=torch.uint8, device="cpu") + + actual = F.adjust_saturation(image, saturation_factor=saturation_factor) + expected = F.to_image(F.adjust_saturation(F.to_pil_image(image), saturation_factor=saturation_factor)) + + assert_close(actual, expected, rtol=0, atol=1)