Scriptable Resize Added #1666

Closed · wants to merge 2 commits
Conversation

surgan12 (Contributor) commented Dec 15, 2019

This PR adds a scriptable (TorchScript-compatible) resize, as part of #1375.
cc: @fmassa

codecov-io commented Dec 15, 2019

Codecov Report

Merging #1666 into master will decrease coverage by 66.02%.
The diff coverage is 0%.


@@            Coverage Diff             @@
##           master   #1666       +/-   ##
==========================================
- Coverage   66.02%      0%   -66.03%     
==========================================
  Files          92      92               
  Lines        7331    7348       +17     
  Branches     1107    1111        +4     
==========================================
- Hits         4840       0     -4840     
- Misses       2175    7348     +5173     
+ Partials      316       0      -316
Impacted Files Coverage Δ
torchvision/transforms/functional_tensor.py 0% <0%> (-62.69%) ⬇️
torchvision/models/quantization/__init__.py 0% <0%> (-100%) ⬇️
torchvision/ops/boxes.py 0% <0%> (-100%) ⬇️
torchvision/datasets/samplers/__init__.py 0% <0%> (-100%) ⬇️
torchvision/ops/_register_onnx_ops.py 0% <0%> (-100%) ⬇️
torchvision/models/__init__.py 0% <0%> (-100%) ⬇️
torchvision/models/segmentation/deeplabv3.py 0% <0%> (-100%) ⬇️
torchvision/io/__init__.py 0% <0%> (-100%) ⬇️
torchvision/models/segmentation/fcn.py 0% <0%> (-100%) ⬇️
torchvision/ops/__init__.py 0% <0%> (-100%) ⬇️
... and 83 more

Last update 2d7c066...ced792f.

surgan12 (Contributor, Author) commented:

@fmassa could you review this, please?

fmassa (Member) left a comment:

Sorry for the delay in reviewing.

I have some comments, let me know what you think.

Also, there are some things I'm not yet sure about, so it might need some discussion.

    else:
        out_img = Fn.interpolate(img.unsqueeze(0), size=size, mode=interpolation)

    return(out_img.clamp(min=0, max=255).squeeze(0))
fmassa (Member) commented:

I'm not sure we should be clamping the input to 0-255 in the function. Although this makes it match the behavior of the Pillow implementation, this only works if the input is in 0-255, which is rarely the case for floating point values.

My take is that we should remove the clamp in this function.

Also, can you remove the parentheses and add a space after the return?
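
A minimal sketch of the requested change (illustrative only, not the code that was merged): drop the clamp and use a plain return statement.

    return out_img.squeeze(0)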

        if w < h:
            ow = size
            oh = int(size * h / w)
            out_img = Fn.interpolate(img.unsqueeze(0), size=(oh, ow), mode=interpolation)
fmassa (Member) commented:

This assumes that the input is not batched (but it could potentially be, e.g. for videos).
Instead, I think we should unsqueeze only if it's a 3D tensor, and in that case squeeze back afterwards.
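
A rough sketch of the conditional unsqueeze/squeeze being suggested, reusing the PR's img, size, and interpolation names (illustrative only, not the code that was eventually merged):

    need_squeeze = False
    if img.dim() == 3:
        # unbatched C x H x W input: add a batch dimension for interpolate
        img = img.unsqueeze(0)
        need_squeeze = True
    out_img = Fn.interpolate(img, size=size, mode=interpolation)
    if need_squeeze:
        # restore the original rank for unbatched inputs
        out_img = out_img.squeeze(0)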

            ow = int(size * w / h)
            out_img = Fn.interpolate(img.unsqueeze(0), size=(oh, ow), mode=interpolation)
    else:
        out_img = Fn.interpolate(img.unsqueeze(0), size=size, mode=interpolation)
fmassa (Member) commented:

This is fine as is, but I think we can simplify the code a bit so that we only need to call interpolate in a single codepath.

A solution would be to build size as a tuple in the int case as well. Something like

    if isinstance(size, int):
        w, h = img.shape[2], img.shape[1]
        if (w <= h and w == size) or (h <= w and h == size):
            return img
        if w < h:
            ow = size
            oh = int(size * h / w)
            size = (oh, ow)
        else:
            oh = size
            ow = int(size * w / h)
            size = (oh, ow)
    out_img = Fn.interpolate(..., size=size, mode=interpolation)
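
For reference, a rough sketch of how this suggestion and the batching comment above could fit together into a single-codepath resize (purely illustrative; the feature that eventually landed lives in #2394):

    import torch
    import torch.nn.functional as Fn


    def resize(img, size, interpolation="bilinear"):
        # normalize an int size to an (oh, ow) tuple so interpolate is called once
        if isinstance(size, int):
            w, h = img.shape[-1], img.shape[-2]
            if (w <= h and w == size) or (h <= w and h == size):
                return img
            if w < h:
                size = (int(size * h / w), size)
            else:
                size = (size, int(size * w / h))
        # unsqueeze unbatched C x H x W inputs and squeeze back afterwards
        need_squeeze = img.dim() == 3
        if need_squeeze:
            img = img.unsqueeze(0)
        out_img = Fn.interpolate(img, size=size, mode=interpolation)
        return out_img.squeeze(0) if need_squeeze else out_img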

@@ -219,3 +220,40 @@ def ten_crop(img, size, vertical_flip=False):
def _blend(img1, img2, ratio):
    bound = 1 if img1.dtype.is_floating_point else 255
    return (ratio * img1 + (1 - ratio) * img2).clamp(0, bound).to(img1.dtype)


def resize(img, size, interpolation="bilinear"):
fmassa (Member) commented:

This is something that I still need to figure out: should we keep the old PIL-style interpolation (integer constants), or the torch-based one (mode strings)? I'm not sure.
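
One hedged illustration of how both conventions could be supported (not part of this PR; torchvision settled this question later): map Pillow's integer constants onto the mode strings that interpolate expects.

    # Pillow constants: NEAREST = 0, BILINEAR = 2, BICUBIC = 3
    _pil_constant_to_mode = {0: "nearest", 2: "bilinear", 3: "bicubic"}


    def _interpolation_to_mode(interpolation):
        # accept either a Pillow integer constant or a mode string
        if isinstance(interpolation, int):
            return _pil_constant_to_mode[interpolation]
        return interpolation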

@@ -138,6 +138,36 @@ def test_ten_crop(self):
            (transforms.ToTensor()(cropped_pil_image[9]) * 255).to(torch.uint8)))
        self.assertTrue(torch.equal(img_tensor, img_tensor_clone))

    def test_resize(self):
fmassa (Member) commented:

Can you also add a test for scriptability?
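
A hedged sketch of what such a test might look like, assuming the test module imports torchvision.transforms.functional_tensor as F_t (this is not the test that was merged):

    def test_resize_scriptable(self):
        # scripting the tensor resize should succeed and match the eager result
        img = torch.randint(0, 255, (3, 32, 46), dtype=torch.uint8).float()
        scripted_resize = torch.jit.script(F_t.resize)
        self.assertTrue(torch.equal(F_t.resize(img, 20), scripted_resize(img, 20)))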

surgan12 (Contributor, Author) commented:

@fmassa, I am getting an error while scripting nn.functional.interpolate:

RuntimeError: isIntList() INTERNAL ASSERT FAILED at /pytorch/aten/src/ATen/core/ivalue_inl.h:544, please report a bug to PyTorch. Expected IntList but got DoubleList
The above operation failed in interpreter, with the following stack trace:
at tt2.py:19:19

    @torch.jit.script
    def script_resize(input, output_size, interpolation):
        # type: (Tensor, BroadcastingList2[int], str) -> Tensor
        return Fn.interpolate(input, output_size, mode=interpolation)
               ~~~~~~~~~~~~~~ <--- HERE
Can you help me fix this?

fmassa (Member) commented Dec 19, 2019

@surgan12 you might just need to specify that size is an int somewhere. But maybe this is an internal bug in PyTorch. Which PyTorch version are you using?
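
A hedged sketch of the kind of fix being suggested (illustrative only; whether it scripts cleanly depends on the PyTorch version, and Fn is assumed to be torch.nn.functional as in the snippet above): make sure the size handed to interpolate is built from Python ints, for example by typing the output dimensions as int and passing an explicit int list.

    @torch.jit.script
    def script_resize(input, oh, ow, interpolation):
        # type: (Tensor, int, int, str) -> Tensor
        # size is an explicit list of ints, so TorchScript sees an IntList
        return Fn.interpolate(input, size=[oh, ow], mode=interpolation)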

surgan12 (Contributor, Author) commented:

@fmassa I am using 1.3.0.
Also, interpolate accepts size either as an int or as a sequence. Let me know what you think.

fmassa (Member) commented Dec 19, 2019

Having size as both an int and a sequence is fine. I think you might need to use a more recent version of PyTorch, like a recent nightly; this might have been fixed in PyTorch already.

surgan12 (Contributor, Author) commented:

Sure @fmassa, I will update and get back to you.

vfdev-5 (Collaborator) commented Oct 22, 2020

This PR can be closed, as the feature has already been implemented (#2394).

vfdev-5 closed this on Oct 22, 2020.