[1/2] Added backward pass on CPU for interpolation with anti-alias option #4208
Merged
Changes from all commits (6 commits):
- 3d9240d  WIP on backward op interpolation with AA (vfdev-5)
- 206c2be  Removed cuda tests and reformat cpp code
- 3529515  Merge branch 'master' into add-backward-interp-aa (vfdev-5)
- 5375e8c  Fixed clang wrong formatting
- 1ed1a16  Added channels last test case
- d3b2d30  Merge branch 'master' into add-backward-interp-aa (fmassa)
@@ -1,3 +1,4 @@
from functools import partial
import itertools
import os
import colorsys
@@ -578,6 +579,52 @@ def test_assert_resize_antialias(interpolation):
        F.resize(tensor, size=(5, 5), interpolation=interpolation, antialias=True)


@pytest.mark.parametrize('dt', [torch.float32, torch.float64, torch.float16])
@pytest.mark.parametrize('size', [[10, 7], [10, 42], [42, 7]])
@pytest.mark.parametrize('interpolation', [BILINEAR, BICUBIC])
def test_interpolate_antialias_backward(dt, size, interpolation):
    # temporarily hard-code device as CPU, CUDA support will be done later
    device = "cpu"

    if dt == torch.float16 and device == "cpu":
        # skip float16 on CPU case
        return

    torch.manual_seed(12)
    if interpolation == BILINEAR:
        forward_op = torch.ops.torchvision._interpolate_bilinear2d_aa
        backward_op = torch.ops.torchvision._interpolate_bilinear2d_aa_backward
    elif interpolation == BICUBIC:
        forward_op = torch.ops.torchvision._interpolate_bicubic2d_aa
        backward_op = torch.ops.torchvision._interpolate_bicubic2d_aa_backward

    class F(torch.autograd.Function):

        @staticmethod
        def forward(ctx, i):
            result = forward_op(i, size, False)
            ctx.save_for_backward(i, result)
            return result

        @staticmethod
        def backward(ctx, grad_output):
            i, result = ctx.saved_tensors
            ishape = i.shape
            oshape = result.shape[2:]
            return backward_op(grad_output, oshape, ishape, False)
Comment on lines +602 to +615 (the custom autograd.Function): This is fine for now. I suppose the next step would be to move those functions to PyTorch so that we have native autograd support?
(A sketch of that end state follows after the diff.)
    x = (
        torch.rand(1, 32, 29, 3, dtype=torch.double, device=device).permute(0, 3, 1, 2).requires_grad_(True),
    )
    assert torch.autograd.gradcheck(F.apply, x, eps=1e-8, atol=1e-6, rtol=1e-6, fast_mode=False)

    x = (
        torch.rand(1, 3, 32, 29, dtype=torch.double, device=device, requires_grad=True),
    )
    assert torch.autograd.gradcheck(F.apply, x, eps=1e-8, atol=1e-6, rtol=1e-6, fast_mode=False)


def check_functional_vs_PIL_vs_scripted(fn, fn_pil, fn_t, config, device, dtype, tol=2.0 + 1e-10, agg_method="max"):
    script_fn = torch.jit.script(fn)
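A quick aside on the gradcheck calls above: torch.autograd.gradcheck builds a numerical Jacobian via finite differences and compares it against the analytical one produced by the backward op, which is why the inputs are created in double precision and why eps/atol/rtol are set explicitly. A minimal, self-contained illustration of the mechanism (not part of the PR; f below is just a stand-in function):

import torch

# Stand-in function: gradcheck compares its analytical gradient (from autograd)
# against a finite-difference estimate, the same check the new test applies to
# the anti-aliased interpolation ops.
def f(t):
    return (t * t).sum()

x = torch.rand(4, dtype=torch.double, requires_grad=True)
assert torch.autograd.gradcheck(f, (x,), eps=1e-6, atol=1e-4)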
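Regarding the inline comment about moving these ops into PyTorch for native autograd support: a rough sketch of that end state, assuming a PyTorch build where torch.nn.functional.interpolate accepts antialias=True for bilinear/bicubic modes (an assumption here, not something this PR adds). The custom autograd.Function wrapper then becomes unnecessary and gradcheck can target the op directly:

import torch
import torch.nn.functional as nnF

# Sketch only: assumes interpolate(..., antialias=True) exists and is
# differentiable, i.e. the "native autograd support" mentioned in the review.
def resize_aa(t):
    return nnF.interpolate(t, size=(10, 7), mode="bilinear",
                           align_corners=False, antialias=True)

x = torch.rand(1, 3, 32, 29, dtype=torch.double, requires_grad=True)
assert torch.autograd.gradcheck(resize_aa, (x,), eps=1e-8, atol=1e-6, rtol=1e-6)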
Review discussion on the use of torch.manual_seed(12) in the new test:

Comment: In general, should we set the seed of the tests in a more automated way via pytest?
Reply: It's fine for now, as this is what we're doing almost everywhere else. Once we're finished with the pytest porting, I'll look into ways to improve the RNG handling in our tests. One thing I'm wondering: does torch.manual_seed(12) leak the RNG state into the rest of the tests?
Reply: I think yes.
Reply: That's pretty bad :) But yeah, it's OK for now to use the old pattern.
Reply: It certainly does leak it, and it's bitten us quite a few times in the past.
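One possible pytest-based option for the RNG leak discussed above (a sketch, not something this PR implements): an autouse fixture that snapshots the global CPU RNG state before each test and restores it afterwards, so torch.manual_seed(12) inside one test does not change the random stream seen by later tests.

import pytest
import torch

@pytest.fixture(autouse=True)
def preserve_rng_state():
    # Save the global CPU RNG state before each test and restore it afterwards,
    # so per-test calls to torch.manual_seed do not leak into later tests.
    # (CUDA generators would need the same treatment via the torch.cuda RNG APIs.)
    state = torch.get_rng_state()
    yield
    torch.set_rng_state(state)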