Skip to content

Commit

Permalink
[fbsync] port tests for transforms.ScaleJitter (#8001)
Browse files Browse the repository at this point in the history
Reviewed By: vmoens

Differential Revision: D50789106

fbshipit-source-id: 85b9b9e072e0e22e0c502c51307a276b0af9ed6d
  • Loading branch information
NicolasHug authored and facebook-github-bot committed Oct 30, 2023
1 parent 537f166 commit bb644e8
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 28 deletions.
28 changes: 0 additions & 28 deletions test/test_transforms_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -519,34 +519,6 @@ def test__transform(self, mocker):
assert isinstance(output_masks, tv_tensors.Mask)


class TestScaleJitter:
def test__get_params(self):
canvas_size = (24, 32)
target_size = (16, 12)
scale_range = (0.5, 1.5)

transform = transforms.ScaleJitter(target_size=target_size, scale_range=scale_range)

sample = make_image(canvas_size)

n_samples = 5
for _ in range(n_samples):

params = transform._get_params([sample])

assert "size" in params
size = params["size"]

assert isinstance(size, tuple) and len(size) == 2
height, width = size

r_min = min(target_size[1] / canvas_size[0], target_size[0] / canvas_size[1]) * scale_range[0]
r_max = min(target_size[1] / canvas_size[0], target_size[0] / canvas_size[1]) * scale_range[1]

assert int(canvas_size[0] * r_min) <= height <= int(canvas_size[0] * r_max)
assert int(canvas_size[1] * r_min) <= width <= int(canvas_size[1] * r_max)


class TestRandomShortestSize:
@pytest.mark.parametrize("min_size,max_size", [([5, 9], 20), ([5, 9], None)])
def test__get_params(self, min_size, max_size):
Expand Down
38 changes: 38 additions & 0 deletions test/test_transforms_v2_refactored.py
Original file line number Diff line number Diff line change
Expand Up @@ -4788,3 +4788,41 @@ def test_transform(self, make_input, dtype, device):
),
make_input(dtype=dtype, device=device),
)


class TestScaleJitter:
# Tests are light because this largely relies on the already tested `resize` kernels.

INPUT_SIZE = (17, 11)
TARGET_SIZE = (12, 13)

@pytest.mark.parametrize(
"make_input",
[make_image_tensor, make_image_pil, make_image, make_bounding_boxes, make_segmentation_mask, make_video],
)
@pytest.mark.parametrize("device", cpu_and_cuda())
def test_transform(self, make_input, device):
if make_input is make_image_pil and device != "cpu":
pytest.skip("PIL image tests with parametrization device!='cpu' will degenerate to that anyway.")

check_transform(transforms.ScaleJitter(self.TARGET_SIZE), make_input(self.INPUT_SIZE, device=device))

def test__get_params(self):
input_size = self.INPUT_SIZE
target_size = self.TARGET_SIZE
scale_range = (0.5, 1.5)

transform = transforms.ScaleJitter(target_size=target_size, scale_range=scale_range)
params = transform._get_params([make_image(input_size)])

assert "size" in params
size = params["size"]

assert isinstance(size, tuple) and len(size) == 2
height, width = size

r_min = min(target_size[1] / input_size[0], target_size[0] / input_size[1]) * scale_range[0]
r_max = min(target_size[1] / input_size[0], target_size[0] / input_size[1]) * scale_range[1]

assert int(input_size[0] * r_min) <= height <= int(input_size[0] * r_max)
assert int(input_size[1] * r_min) <= width <= int(input_size[1] * r_max)

0 comments on commit bb644e8

Please sign in to comment.