Skip to content

Commit

Permalink
Added bicubic support for interpolation with AA (#3810)
Browse files Browse the repository at this point in the history
* Added support for bicubic mode with AA

* Updated comment in the test
  • Loading branch information
vfdev-5 authored May 13, 2021
1 parent e35793a commit 0fd0f50
Show file tree
Hide file tree
Showing 5 changed files with 192 additions and 75 deletions.
13 changes: 11 additions & 2 deletions test/test_functional_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1021,7 +1021,7 @@ def test_perspective_interpolation_warning(tester):
@pytest.mark.parametrize('device', ["cpu", ])
@pytest.mark.parametrize('dt', [None, torch.float32, torch.float64, torch.float16])
@pytest.mark.parametrize('size', [[96, 72], [96, 420], [420, 72]])
@pytest.mark.parametrize('interpolation', [BILINEAR, ])
@pytest.mark.parametrize('interpolation', [BILINEAR, BICUBIC])
def test_resize_antialias(device, dt, size, interpolation, tester):

if dt == torch.float16 and device == "cpu":
Expand Down Expand Up @@ -1051,8 +1051,17 @@ def test_resize_antialias(device, dt, size, interpolation, tester):
tester.approxEqualTensorToPIL(
resized_tensor_f, resized_pil_img, tol=0.5, msg=f"{size}, {interpolation}, {dt}"
)

accepted_tol = 1.0 + 1e-5
if interpolation == BICUBIC:
# this overall mean value to make the tests pass
# High value is mostly required for test cases with
# downsampling and upsampling where we can not exactly
# match PIL implementation.
accepted_tol = 15.0

tester.approxEqualTensorToPIL(
resized_tensor_f, resized_pil_img, tol=1.0 + 1e-5, agg_method="max",
resized_tensor_f, resized_pil_img, tol=accepted_tol, agg_method="max",
msg=f"{size}, {interpolation}, {dt}"
)

Expand Down
223 changes: 155 additions & 68 deletions torchvision/csrc/ops/cpu/interpolate_aa_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -141,66 +141,7 @@ void ti_cpu_upsample_generic_aa(
// Helper structs to use with ti_upsample_generic_Nd_kernel_impl
template <typename index_t, typename scalar_t>
struct HelperInterpBase {
static inline void init_indices_weights(
std::vector<Tensor>& output,
int64_t output_size,
int64_t ndims,
int64_t reshape_dim,
int interp_size) {
auto new_shape = std::vector<int64_t>(ndims, 1);
new_shape[reshape_dim] = output_size;

for (int j = 0; j < interp_size; j++) {
output.emplace_back(
empty(new_shape, CPU(c10::CppTypeToScalarType<index_t>())));
output.emplace_back(
empty(new_shape, CPU(c10::CppTypeToScalarType<scalar_t>())));
}
}
};

template <typename index_t, typename scalar_t>
struct HelperInterpLinear : public HelperInterpBase<index_t, scalar_t> {
static const int interp_size = 2;

static inline std::vector<Tensor> compute_indices_weights(
int64_t input_size,
int64_t output_size,
int64_t stride,
int64_t ndims,
int64_t reshape_dim,
bool align_corners,
const c10::optional<double> opt_scale,
bool antialias,
int& out_interp_size) {
scalar_t scale = area_pixel_compute_scale<scalar_t>(
input_size, output_size, align_corners, opt_scale);
TORCH_INTERNAL_ASSERT(antialias);

return _compute_indices_weights_aa(
input_size,
output_size,
stride,
ndims,
reshape_dim,
align_corners,
scale,
out_interp_size);
}

// taken from
// https://github.com/python-pillow/Pillow/blob/6812205f18ca4ef54372e87e1a13ce4a859434df/
// src/libImaging/Resample.c#L20-L29
static inline scalar_t _filter(scalar_t x) {
if (x < 0.0) {
x = -x;
}
if (x < 1.0) {
return 1.0 - x;
}
return 0.0;
}

template <typename filter_fn_t>
static inline std::vector<Tensor> _compute_indices_weights_aa(
int64_t input_size,
int64_t output_size,
Expand All @@ -209,14 +150,15 @@ struct HelperInterpLinear : public HelperInterpBase<index_t, scalar_t> {
int64_t reshape_dim,
bool align_corners,
scalar_t scale,
int& out_interp_size) {
int interp_size = HelperInterpLinear<index_t, scalar_t>::interp_size;
int& in_out_interp_size,
filter_fn_t filter_fn) {
int interp_size = in_out_interp_size;
scalar_t support =
(scale >= 1.0) ? (interp_size / 2) * scale : interp_size / 2 * 1.0;
(scale >= 1.0) ? (interp_size * 0.5) * scale : interp_size * 0.5;
interp_size = (int)ceilf(support) * 2 + 1;

// return interp_size
out_interp_size = interp_size;
in_out_interp_size = interp_size;

std::vector<Tensor> output;
auto new_shape = std::vector<int64_t>(ndims, 1);
Expand Down Expand Up @@ -269,7 +211,7 @@ struct HelperInterpLinear : public HelperInterpBase<index_t, scalar_t> {

total_w = 0.0;
for (j = 0; j < xmax; j++) {
scalar_t w = _filter((j + xmin - center + 0.5) * invscale);
scalar_t w = filter_fn((j + xmin - center + 0.5) * invscale);
wt_ptr[i * interp_size + j] = w;
total_w += w;
}
Expand All @@ -287,6 +229,102 @@ struct HelperInterpLinear : public HelperInterpBase<index_t, scalar_t> {
}
};

template <typename index_t, typename scalar_t>
struct HelperInterpLinear : public HelperInterpBase<index_t, scalar_t> {
static const int interp_size = 2;

// taken from
// https://github.com/python-pillow/Pillow/blob/6812205f18ca4ef54372e87e1a13ce4a859434df/
// src/libImaging/Resample.c#L20-L29
static inline scalar_t _filter(scalar_t x) {
if (x < 0.0) {
x = -x;
}
if (x < 1.0) {
return 1.0 - x;
}
return 0.0;
}

static inline std::vector<Tensor> compute_indices_weights(
int64_t input_size,
int64_t output_size,
int64_t stride,
int64_t ndims,
int64_t reshape_dim,
bool align_corners,
const c10::optional<double> opt_scale,
bool antialias,
int& out_interp_size) {
TORCH_INTERNAL_ASSERT(antialias);
scalar_t scale = area_pixel_compute_scale<scalar_t>(
input_size, output_size, align_corners, opt_scale);

out_interp_size = HelperInterpLinear<index_t, scalar_t>::interp_size;
return HelperInterpLinear<index_t, scalar_t>::_compute_indices_weights_aa(
input_size,
output_size,
stride,
ndims,
reshape_dim,
align_corners,
scale,
out_interp_size,
_filter);
}
};

template <typename index_t, typename scalar_t>
struct HelperInterpCubic : public HelperInterpBase<index_t, scalar_t> {
static const int interp_size = 4;

static inline std::vector<Tensor> compute_indices_weights(
int64_t input_size,
int64_t output_size,
int64_t stride,
int64_t ndims,
int64_t reshape_dim,
bool align_corners,
const c10::optional<double> opt_scale,
bool antialias,
int& out_interp_size) {
TORCH_INTERNAL_ASSERT(antialias);
scalar_t scale = area_pixel_compute_scale<scalar_t>(
input_size, output_size, align_corners, opt_scale);

out_interp_size = HelperInterpCubic<index_t, scalar_t>::interp_size;
return HelperInterpCubic<index_t, scalar_t>::_compute_indices_weights_aa(
input_size,
output_size,
stride,
ndims,
reshape_dim,
align_corners,
scale,
out_interp_size,
_filter);
}

// taken from
// https://github.com/python-pillow/Pillow/blob/6812205f18ca4ef54372e87e1a13ce4a859434df/
// src/libImaging/Resample.c#L46-L62
static inline scalar_t _filter(scalar_t x) {
// https://en.wikipedia.org/wiki/Bicubic_interpolation#Bicubic_convolution_algorithm
#define a -0.5
if (x < 0.0) {
x = -x;
}
if (x < 1.0) {
return ((a + 2.0) * x - (a + 3.0)) * x * x + 1;
}
if (x < 2.0) {
return (((x - 5) * x + 8) * x - 4) * a;
}
return 0.0;
#undef a
}
};

template <
typename index_t,
int out_ndims,
Expand Down Expand Up @@ -396,16 +434,15 @@ void ti_separable_upsample_generic_Nd_kernel_impl(
index_t,
out_ndims,
scale_t,
HelperInterpLinear>(
F>(
temp_output, temp_input, interp_dim, align_corners, scales, antialias);
temp_input = temp_output;
}
_ti_separable_upsample_generic_Nd_kernel_impl_single_dim<
index_t,
out_ndims,
scale_t,
HelperInterpLinear>(
output, temp_input, 2, align_corners, scales, antialias);
F>(output, temp_input, 2, align_corners, scales, antialias);
}

void _ti_upsample_bilinear2d_kernel_impl(
Expand All @@ -423,6 +460,21 @@ void _ti_upsample_bilinear2d_kernel_impl(
output, input, align_corners, {scales_h, scales_w}, antialias);
}

void _ti_upsample_bicubic2d_kernel_impl(
Tensor& output,
const Tensor& input,
bool align_corners,
c10::optional<double> scales_h,
c10::optional<double> scales_w,
bool antialias) {
ti_separable_upsample_generic_Nd_kernel_impl<
int64_t,
2,
scale_t,
HelperInterpCubic>(
output, input, align_corners, {scales_h, scales_w}, antialias);
}

} // namespace internal_upsample
} // namespace native
} // namespace at
Expand Down Expand Up @@ -463,6 +515,37 @@ at::Tensor interpolate_linear_aa_forward_kernel(
return output;
}

at::Tensor interpolate_bicubic_aa_forward_kernel(
const at::Tensor& input,
at::IntArrayRef output_size,
bool align_corners) {
TORCH_CHECK(input.device().is_cpu(), "input must be a CPU tensor");

c10::optional<c10::ArrayRef<double>> scale_factors = {};

// Copied from UpSampleBilinear2d.cpp
auto output = at::empty({0}, input.options());
auto osize = at::native::upsample::compute_output_size(
input.sizes(), output_size, scale_factors);
auto scale_h = at::native::upsample::get_scale_value(scale_factors, 0);
auto scale_w = at::native::upsample::get_scale_value(scale_factors, 1);
auto full_output_size =
at::native::upsample_2d_common_check(input.sizes(), osize);

// Allow for empty batch size but not other dimensions
TORCH_CHECK(
input.numel() != 0 ||
c10::multiply_integers(
input.sizes().begin() + 1, input.sizes().end()),
"Non-empty 4D data tensor expected but got a tensor with sizes ",
input.sizes());

output.resize_(full_output_size, input.suggest_memory_format());
at::native::internal_upsample::_ti_upsample_bicubic2d_kernel_impl(
output, input, align_corners, scale_h, scale_w, /*antialias=*/true);
return output;
}

// TODO: Implement backward function
// at::Tensor interpolate_linear_aa_backward_kernel(
// const at::Tensor& grad) {
Expand All @@ -475,6 +558,10 @@ TORCH_LIBRARY_IMPL(torchvision, CPU, m) {
m.impl(
TORCH_SELECTIVE_NAME("torchvision::_interpolate_linear_aa"),
TORCH_FN(interpolate_linear_aa_forward_kernel));
m.impl(
TORCH_SELECTIVE_NAME("torchvision::_interpolate_bicubic_aa"),
TORCH_FN(interpolate_bicubic_aa_forward_kernel));

// TODO: Implement backward function
// m.impl(
// TORCH_SELECTIVE_NAME("torchvision::_interpolate_linear_aa_backward"),
Expand Down
16 changes: 15 additions & 1 deletion torchvision/csrc/ops/interpolate_aa.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,23 @@ at::Tensor interpolate_linear_aa(
{
static auto op =
c10::Dispatcher::singleton()
.findSchemaOrThrow("torchvision::interpolate_linear_aa", "")
.findSchemaOrThrow("torchvision::_interpolate_linear_aa", "")
.typed<decltype(interpolate_linear_aa)>();
return op.call(input, output_size, align_corners);
}

at::Tensor interpolate_bicubic_aa(
const at::Tensor& input, // Input image
at::IntArrayRef output_size, // Output image size
bool align_corners) // The flag to align corners
{
static auto op =
c10::Dispatcher::singleton()
.findSchemaOrThrow("torchvision::_interpolate_bicubic_aa", "")
.typed<decltype(_interpolate_bicubic_aa)>();
return op.call(input, output_size, align_corners);
}

namespace detail {

// TODO: Implement backward function
Expand All @@ -33,6 +45,8 @@ namespace detail {
TORCH_LIBRARY_FRAGMENT(torchvision, m) {
m.def(TORCH_SELECTIVE_SCHEMA(
"torchvision::_interpolate_linear_aa(Tensor input, int[] output_size, bool align_corners) -> Tensor"));
m.def(TORCH_SELECTIVE_SCHEMA(
"torchvision::_interpolate_bicubic_aa(Tensor input, int[] output_size, bool align_corners) -> Tensor"));
// TODO: Implement backward function
// m.def(TORCH_SELECTIVE_SCHEMA(
// "torchvision::_interpolate_linear_aa_backward(Tensor grad, Tensor rois,
Expand Down
5 changes: 5 additions & 0 deletions torchvision/csrc/ops/interpolate_aa.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,11 @@ VISION_API at::Tensor _interpolate_linear_aa(
at::IntArrayRef output_size,
bool align_corners = false);

VISION_API at::Tensor _interpolate_bicubic_aa(
const at::Tensor& input,
at::IntArrayRef output_size,
bool align_corners = false);

namespace detail {

// TODO: Implement backward function
Expand Down
10 changes: 6 additions & 4 deletions torchvision/transforms/functional_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -503,8 +503,8 @@ def resize(
if antialias is None:
antialias = False

if antialias and interpolation not in ["bilinear", ]:
raise ValueError("Antialias option is supported for bilinear interpolation mode only")
if antialias and interpolation not in ["bilinear", "bicubic"]:
raise ValueError("Antialias option is supported for bilinear and bicubic interpolation modes only")

w, h = _get_image_size(img)

Expand Down Expand Up @@ -537,8 +537,10 @@ def resize(
align_corners = False if interpolation in ["bilinear", "bicubic"] else None

if antialias:
# Apply antialias for donwsampling on both dims
img = torch.ops.torchvision._interpolate_linear_aa(img, [new_h, new_w], align_corners=False)
if interpolation == "bilinear":
img = torch.ops.torchvision._interpolate_linear_aa(img, [new_h, new_w], align_corners=False)
elif interpolation == "bicubic":
img = torch.ops.torchvision._interpolate_bicubic_aa(img, [new_h, new_w], align_corners=False)
else:
img = interpolate(img, size=[new_h, new_w], mode=interpolation, align_corners=align_corners)

Expand Down

0 comments on commit 0fd0f50

Please sign in to comment.