Skip to content

Commit

Permalink
add new tests for F.adjust_saturation
Browse files Browse the repository at this point in the history
  • Loading branch information
pmeier committed Sep 11, 2023
1 parent c959590 commit 8cd9fdf
Showing 1 changed file with 42 additions and 0 deletions.
42 changes: 42 additions & 0 deletions test/test_transforms_v2_refactored.py
Original file line number Diff line number Diff line change
Expand Up @@ -4494,3 +4494,45 @@ def test_correctness_image(self, hue_factor):

mae = (actual.float() - expected.float()).abs().mean()
assert mae < 2


class TestAdjustSaturation:
@pytest.mark.parametrize("dtype", [torch.uint8, torch.float32])
@pytest.mark.parametrize("device", cpu_and_cuda())
def test_kernel_image(self, dtype, device):
check_kernel(F.adjust_saturation_image, make_image(dtype=dtype, device=device), saturation_factor=0.5)

def test_kernel_video(self):
check_kernel(F.adjust_saturation_video, make_video(), saturation_factor=0.5)

@pytest.mark.parametrize("make_input", [make_image_tensor, make_image, make_image_pil, make_video])
def test_functional(self, make_input):
check_functional(F.adjust_saturation, make_input(), saturation_factor=0.5)

@pytest.mark.parametrize(
("kernel", "input_type"),
[
(F.adjust_saturation_image, torch.Tensor),
(F._adjust_saturation_image_pil, PIL.Image.Image),
(F.adjust_saturation_image, tv_tensors.Image),
(F.adjust_saturation_video, tv_tensors.Video),
],
)
def test_functional_signature(self, kernel, input_type):
check_functional_kernel_signature_match(F.adjust_saturation, kernel=kernel, input_type=input_type)

def test_functional_error(self):
with pytest.raises(TypeError, match="permitted channel values are 1 or 3"):
F.adjust_saturation(make_image(color_space="RGBA"), saturation_factor=0.5)

with pytest.raises(ValueError, match="is not non-negative"):
F.adjust_saturation(make_image(), saturation_factor=-1)

@pytest.mark.parametrize("saturation_factor", [0.1, 0.5, 1.0])
def test_correctness_image(self, saturation_factor):
image = make_image(dtype=torch.uint8, device="cpu")

actual = F.adjust_saturation(image, saturation_factor=saturation_factor)
expected = F.to_image(F.adjust_saturation(F.to_pil_image(image), saturation_factor=saturation_factor))

assert_close(actual, expected, rtol=0, atol=1)

0 comments on commit 8cd9fdf

Please sign in to comment.