-
Notifications
You must be signed in to change notification settings - Fork 7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Scriptable Resize Added #1666
Scriptable Resize Added #1666
Conversation
Codecov Report
@@ Coverage Diff @@
## master #1666 +/- ##
==========================================
- Coverage 66.02% 0% -66.03%
==========================================
Files 92 92
Lines 7331 7348 +17
Branches 1107 1111 +4
==========================================
- Hits 4840 0 -4840
- Misses 2175 7348 +5173
+ Partials 316 0 -316
Continue to review full report at Codecov.
|
@fmassa could you review this please. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry for the delay in reviewing.
I have some comments, let me know what you think.
Also, there are some things I'm not yet sure about, so it might need some discussion
else: | ||
out_img = Fn.interpolate(img.unsqueeze(0), size=size, mode=interpolation) | ||
|
||
return(out_img.clamp(min=0, max=255).squeeze(0)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm not sure we should be clamping the input to 0-255 in the function. Although this makes it match the behavior of the Pillow implementation, this only works if the input is in 0-255, which is rarely the case for floating point values.
My take is that we should remove the clamp in this function.
Also, can you remove the parenthesis and add a space after the return?
if w < h: | ||
ow = size | ||
oh = int(size * h / w) | ||
out_img = Fn.interpolate(img.unsqueeze(0), size=(oh, ow), mode=interpolation) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This assumes that the input is not batched (but it could potentially be, like for videos).
Instead, I think we should unsqueeze only if it's a 3d tensor, and in this case we should squeeze back.
ow = int(size * w / h) | ||
out_img = Fn.interpolate(img.unsqueeze(0), size=(oh, ow), mode=interpolation) | ||
else: | ||
out_img = Fn.interpolate(img.unsqueeze(0), size=size, mode=interpolation) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is fine as is, but I think we can simplify a bit the code so that we only need to call interpolate
in a single codepath.
A solution would be to create the size
in the int
case to be a tuple. Something like
if isinstance(size, int):
w, h = img.shape[2], img.shape[1]
if (w <= h and w == size) or (h <= w and h == size):
return img
if w < h:
ow = size
oh = int(size * h / w)
size = (oh, ow)
else:
oh = size
ow = int(size * w / h)
size = (oh, ow)
out_img = Fn.interpolate(..., size=size, mode=interpolation)
@@ -219,3 +220,40 @@ def ten_crop(img, size, vertical_flip=False): | |||
def _blend(img1, img2, ratio): | |||
bound = 1 if img1.dtype.is_floating_point else 255 | |||
return (ratio * img1 + (1 - ratio) * img2).clamp(0, bound).to(img1.dtype) | |||
|
|||
|
|||
def resize(img, size, interpolation="bilinear"): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is something that I still need to figure out: should we use the old PIL-based interpolation
with int, or the torch-based one? I'm not sure.
@@ -138,6 +138,36 @@ def test_ten_crop(self): | |||
(transforms.ToTensor()(cropped_pil_image[9]) * 255).to(torch.uint8))) | |||
self.assertTrue(torch.equal(img_tensor, img_tensor_clone)) | |||
|
|||
def test_resize(self): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you also add a test for scriptability?
@fmassa, I am getting error while scripting nn.functional.interpolate `RuntimeError: isIntList() INTERNAL ASSERT FAILED at /pytorch/aten/src/ATen/core/ivalue_inl.h:544, please report a bug to PyTorch. Expected IntList but got DoubleList @torch.jit.script |
@surgan12 you might just need to specify that |
@fmassa I am using 1.3.0. |
having both size and |
Sure @fmassa , will just update and get back. |
This PR can be closed as the feature is already implemented (#2394) |
This PR adds torchscriptable resize as a part of #1375
cc : @fmassa