
[1/2] Added backward pass on CPU for interpolation with anti-alias option #4208

Merged: 6 commits, Jul 28, 2021
47 changes: 47 additions & 0 deletions test/test_functional_tensor.py
@@ -1,3 +1,4 @@
from functools import partial
import itertools
import os
import colorsys
@@ -578,6 +579,52 @@ def test_assert_resize_antialias(interpolation):
F.resize(tensor, size=(5, 5), interpolation=interpolation, antialias=True)


@pytest.mark.parametrize('dt', [torch.float32, torch.float64, torch.float16])
@pytest.mark.parametrize('size', [[10, 7], [10, 42], [42, 7]])
@pytest.mark.parametrize('interpolation', [BILINEAR, BICUBIC])
def test_interpolate_antialias_backward(dt, size, interpolation):

# temporarily hard-code device as CPU, CUDA support will be done later
device = "cpu"

if dt == torch.float16 and device == "cpu":
# skip float16 on CPU case
return

torch.manual_seed(12)
Member:

In general, should we set the seed of the tests in a more automated way via pytest?

Member:

It's fine for now, as this is what we're doing almost everywhere else.
Once we're finished with the pytest porting, I'll look into ways to improve the RNG handling in our tests.
One thing I'm wondering is: does torch.manual_seed(12) leak the rng for the rest of the tests?

Collaborator Author:

> does torch.manual_seed(12) leak the rng for the rest of the tests?

I think yes.

Member:

that's pretty bad :)

but yeah, it's OK for now to use the old pattern.

Contributor:

it certainly does leak it and it's bitten us quite a few times in the past.
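
As an illustration only (not part of this PR), one way to keep the seed from leaking is to fork the RNG state around the seeded section; this sketch assumes torch.random.fork_rng is acceptable in the test suite:

    import torch

    def _make_seeded_input():
        # fork_rng saves the current global RNG state and restores it on exit,
        # so the manual_seed call below does not affect tests that run afterwards.
        with torch.random.fork_rng():
            torch.manual_seed(12)
            return torch.rand(1, 3, 32, 29, dtype=torch.double, requires_grad=True)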

if interpolation == BILINEAR:
forward_op = torch.ops.torchvision._interpolate_bilinear2d_aa
backward_op = torch.ops.torchvision._interpolate_bilinear2d_aa_backward
elif interpolation == BICUBIC:
forward_op = torch.ops.torchvision._interpolate_bicubic2d_aa
backward_op = torch.ops.torchvision._interpolate_bicubic2d_aa_backward

class F(torch.autograd.Function):

@staticmethod
def forward(ctx, i):
result = forward_op(i, size, False)
ctx.save_for_backward(i, result)
return result

@staticmethod
def backward(ctx, grad_output):
i, result = ctx.saved_tensors
ishape = i.shape
oshape = result.shape[2:]
return backward_op(grad_output, oshape, ishape, False)
Comment on lines +602 to +615
Member:

This is fine for now. I suppose the next step would be to move those functions to PyTorch so that we have native autograd support?
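
For illustration, assuming the ops eventually get native autograd formulas upstream (hypothetical, not what this PR implements), the custom autograd.Function wrapper above would no longer be needed and gradcheck could be run on the op directly:

    import torch
    import torchvision  # importing torchvision registers the torch.ops.torchvision ops

    # Hypothetical: this only works once a backward formula is registered for the op;
    # with this PR alone, the autograd.Function wrapper above is still required.
    def _check_native_autograd(size=(10, 7)):
        op = torch.ops.torchvision._interpolate_bilinear2d_aa
        x = torch.rand(1, 3, 32, 29, dtype=torch.double, requires_grad=True)
        return torch.autograd.gradcheck(
            lambda t: op(t, list(size), False), (x,), eps=1e-8, atol=1e-6, rtol=1e-6
        )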


x = (
torch.rand(1, 32, 29, 3, dtype=torch.double, device=device).permute(0, 3, 1, 2).requires_grad_(True),
)
assert torch.autograd.gradcheck(F.apply, x, eps=1e-8, atol=1e-6, rtol=1e-6, fast_mode=False)

x = (
torch.rand(1, 3, 32, 29, dtype=torch.double, device=device, requires_grad=True),
)
assert torch.autograd.gradcheck(F.apply, x, eps=1e-8, atol=1e-6, rtol=1e-6, fast_mode=False)


def check_functional_vs_PIL_vs_scripted(fn, fn_pil, fn_t, config, device, dtype, tol=2.0 + 1e-10, agg_method="max"):

script_fn = torch.jit.script(fn)