Skip to content

Commit

Permalink
Merge branch 'main' into port/center-crop
Browse files Browse the repository at this point in the history
  • Loading branch information
pmeier authored Sep 8, 2023
2 parents d10fffd + d78b462 commit 81bb6d5
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 9 deletions.
25 changes: 18 additions & 7 deletions test/test_transforms_v2_refactored.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,11 +312,12 @@ def adapt_fill(value, *, dtype):
return value

max_value = get_max_value(dtype)
value_type = float if dtype.is_floating_point else int

if isinstance(value, (int, float)):
return type(value)(value * max_value)
return value_type(value * max_value)
elif isinstance(value, (list, tuple)):
return type(value)(type(v)(v * max_value) for v in value)
return type(value)(value_type(v * max_value) for v in value)
else:
raise ValueError(f"fill should be an int or float, or a list or tuple of the former, but got '{value}'.")

Expand Down Expand Up @@ -417,6 +418,10 @@ def affine_bounding_boxes(bounding_boxes):
)


# turns all warnings into errors for this module
pytestmark = pytest.mark.filterwarnings("error")


class TestResize:
INPUT_SIZE = (17, 11)
OUTPUT_SIZES = [17, [17], (17,), [12, 13], (12, 13)]
Expand Down Expand Up @@ -2577,15 +2582,19 @@ def test_functional_image_correctness(self, kwargs):
def test_transform(self, param, value, make_input):
input = make_input(self.INPUT_SIZE)

kwargs = {param: value}
if param == "fill":
# 1. size is required
# 2. the fill parameter only has an affect if we need padding
kwargs["size"] = [s + 4 for s in self.INPUT_SIZE]

if isinstance(input, tv_tensors.Mask) and isinstance(value, (tuple, list)):
pytest.skip("F.pad_mask doesn't support non-scalar fill.")

kwargs = dict(
# 1. size is required
# 2. the fill parameter only has an affect if we need padding
size=[s + 4 for s in self.INPUT_SIZE],
fill=adapt_fill(value, dtype=input.dtype if isinstance(input, torch.Tensor) else torch.uint8),
)
else:
kwargs = {param: value}

check_transform(
transforms.RandomCrop(**kwargs, pad_if_needed=True),
input,
Expand Down Expand Up @@ -3478,6 +3487,8 @@ def test_transform_errors(self):
def test_image_correctness(self, padding, padding_mode, fill, fn):
image = make_image(dtype=torch.uint8, device="cpu")

fill = adapt_fill(fill, dtype=torch.uint8)

actual = fn(image, padding=padding, padding_mode=padding_mode, fill=fill)
expected = F.to_image(F.pad(F.to_pil_image(image), padding=padding, padding_mode=padding_mode, fill=fill))

Expand Down
4 changes: 3 additions & 1 deletion torchvision/models/detection/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,10 @@ def _resize_image_and_masks(
) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
if torchvision._is_tracing():
im_shape = _get_shape_onnx(image)
else:
elif torch.jit.is_scripting():
im_shape = torch.tensor(image.shape[-2:])
else:
im_shape = image.shape[-2:]

size: Optional[List[int]] = None
scale_factor: Optional[float] = None
Expand Down
6 changes: 5 additions & 1 deletion torchvision/transforms/v2/functional/_geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -1235,7 +1235,11 @@ def _pad_with_vector_fill(

output = _pad_with_scalar_fill(image, torch_padding, fill=0, padding_mode="constant")
left, right, top, bottom = torch_padding
fill = torch.tensor(fill, dtype=image.dtype, device=image.device).reshape(-1, 1, 1)

# We are creating the tensor in the autodetected dtype first and convert to the right one after to avoid an implicit
# float -> int conversion. That happens for example for the valid input of a uint8 image with floating point fill
# value.
fill = torch.tensor(fill, device=image.device).to(dtype=image.dtype).reshape(-1, 1, 1)

if top > 0:
output[..., :top, :] = fill
Expand Down

0 comments on commit 81bb6d5

Please sign in to comment.