Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SymIntify roi_align #7448

Merged
merged 9 commits into from
Mar 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion torchvision/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from modulefinder import Module

import torch
from torchvision import datasets, io, models, ops, transforms, utils
from torchvision import _meta_registrations, datasets, io, models, ops, transforms, utils

from .extension import _HAS_OPS

Expand Down
48 changes: 48 additions & 0 deletions torchvision/_meta_registrations.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import torch
import torch.library

# Ensure that torch.ops.torchvision is visible
import torchvision.extension # noqa: F401

from torch._prims_common import check

# Library handle used to register Meta-backend kernels ("IMPL" for dispatch key
# "Meta") against ops already defined in the torchvision namespace.
_meta_lib = torch.library.Library("torchvision", "IMPL", "Meta")

# Shorthand for the torchvision op namespace (populated by the extension import above).
vision = torch.ops.torchvision


def register_meta(op):
    """Build a decorator that installs a function as the Meta kernel for ``op``.

    The decorated function is registered with ``_meta_lib`` and then returned
    unchanged, so the name it is bound to remains a plain callable.
    """

    def decorator(meta_kernel):
        # Registration is the side effect; hand the kernel back untouched so
        # decorators can be stacked and the function stays importable.
        _meta_lib.impl(op, meta_kernel)
        return meta_kernel

    return decorator


@register_meta(vision.roi_align.default)
def meta_roi_align(input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio, aligned):
    """Shape-propagation (Meta) kernel for ``torchvision::roi_align``.

    Validates the rois shape and the input/rois dtype agreement, then returns
    an empty tensor with the forward output shape
    ``(num_rois, channels, pooled_height, pooled_width)`` — no real compute.
    """
    check(rois.size(1) == 5, lambda: "rois must have shape as Tensor[K, 5]")
    check(
        input.dtype == rois.dtype,
        lambda: (
            "Expected tensor for input to have the same type as tensor for rois; "
            f"but type {input.dtype} does not equal {rois.dtype}"
        ),
    )
    # The 4-way unpack doubles as an implicit rank check on the feature map.
    batch, channels, height, width = input.size()
    return input.new_empty((rois.size(0), channels, pooled_height, pooled_width))


@register_meta(vision._roi_align_backward.default)
def meta_roi_align_backward(
    grad, rois, spatial_scale, pooled_height, pooled_width, batch_size, channels, height, width, sampling_ratio, aligned
):
    """Shape-propagation (Meta) kernel for ``torchvision::_roi_align_backward``.

    Returns an empty gradient tensor shaped like the forward input,
    ``(batch_size, channels, height, width)``, after checking dtype agreement.
    """

    def _dtype_mismatch_msg():
        # Only evaluated when the check below fails.
        return (
            "Expected tensor for grad to have the same type as tensor for rois; "
            f"but type {grad.dtype} does not equal {rois.dtype}"
        )

    check(grad.dtype == rois.dtype, _dtype_mismatch_msg)
    grad_input_shape = (batch_size, channels, height, width)
    return grad.new_empty(grad_input_shape)
54 changes: 27 additions & 27 deletions torchvision/csrc/ops/autograd/roi_align_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,19 +15,19 @@ class ROIAlignFunction : public torch::autograd::Function<ROIAlignFunction> {
const torch::autograd::Variable& input,
const torch::autograd::Variable& rois,
double spatial_scale,
int64_t pooled_height,
int64_t pooled_width,
c10::SymInt pooled_height,
c10::SymInt pooled_width,
int64_t sampling_ratio,
bool aligned) {
ctx->saved_data["spatial_scale"] = spatial_scale;
ctx->saved_data["pooled_height"] = pooled_height;
ctx->saved_data["pooled_width"] = pooled_width;
ctx->saved_data["sampling_ratio"] = sampling_ratio;
ctx->saved_data["aligned"] = aligned;
ctx->saved_data["input_shape"] = input.sizes();
ctx->saved_data["input_shape"] = input.sym_sizes();
ctx->save_for_backward({rois});
at::AutoDispatchBelowADInplaceOrView g;
auto result = roi_align(
auto result = roi_align_symint(
input,
rois,
spatial_scale,
Expand All @@ -44,17 +44,17 @@ class ROIAlignFunction : public torch::autograd::Function<ROIAlignFunction> {
// Use data saved in forward
auto saved = ctx->get_saved_variables();
auto rois = saved[0];
auto input_shape = ctx->saved_data["input_shape"].toIntList();
auto grad_in = detail::_roi_align_backward(
auto input_shape = ctx->saved_data["input_shape"].toList();
auto grad_in = detail::_roi_align_backward_symint(
grad_output[0],
rois,
ctx->saved_data["spatial_scale"].toDouble(),
ctx->saved_data["pooled_height"].toInt(),
ctx->saved_data["pooled_width"].toInt(),
input_shape[0],
input_shape[1],
input_shape[2],
input_shape[3],
ctx->saved_data["pooled_height"].toSymInt(),
ctx->saved_data["pooled_width"].toSymInt(),
input_shape[0].get().toSymInt(),
input_shape[1].get().toSymInt(),
input_shape[2].get().toSymInt(),
input_shape[3].get().toSymInt(),
ctx->saved_data["sampling_ratio"].toInt(),
ctx->saved_data["aligned"].toBool());
return {
Expand All @@ -77,16 +77,16 @@ class ROIAlignBackwardFunction
const torch::autograd::Variable& grad,
const torch::autograd::Variable& rois,
double spatial_scale,
int64_t pooled_height,
int64_t pooled_width,
int64_t batch_size,
int64_t channels,
int64_t height,
int64_t width,
c10::SymInt pooled_height,
c10::SymInt pooled_width,
c10::SymInt batch_size,
c10::SymInt channels,
c10::SymInt height,
c10::SymInt width,
int64_t sampling_ratio,
bool aligned) {
at::AutoDispatchBelowADInplaceOrView g;
auto result = detail::_roi_align_backward(
auto result = detail::_roi_align_backward_symint(
grad,
rois,
spatial_scale,
Expand All @@ -112,8 +112,8 @@ at::Tensor roi_align_autograd(
const at::Tensor& input,
const at::Tensor& rois,
double spatial_scale,
int64_t pooled_height,
int64_t pooled_width,
c10::SymInt pooled_height,
c10::SymInt pooled_width,
int64_t sampling_ratio,
bool aligned) {
return ROIAlignFunction::apply(
Expand All @@ -130,12 +130,12 @@ at::Tensor roi_align_backward_autograd(
const at::Tensor& grad,
const at::Tensor& rois,
double spatial_scale,
int64_t pooled_height,
int64_t pooled_width,
int64_t batch_size,
int64_t channels,
int64_t height,
int64_t width,
c10::SymInt pooled_height,
c10::SymInt pooled_width,
c10::SymInt batch_size,
c10::SymInt channels,
c10::SymInt height,
c10::SymInt width,
int64_t sampling_ratio,
bool aligned) {
return ROIAlignBackwardFunction::apply(
Expand Down
59 changes: 57 additions & 2 deletions torchvision/csrc/ops/roi_align.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,31 @@ at::Tensor roi_align(
aligned);
}

// SymInt-typed entry point for torchvision::roi_align. Same semantics as
// roi_align, but pooled_height / pooled_width are c10::SymInt so the call can
// participate in symbolic-shape tracing.
at::Tensor roi_align_symint(
    const at::Tensor& input, // Input feature map.
    const at::Tensor& rois, // List of ROIs to pool over.
    double spatial_scale, // Scale of the image features; ROIs are scaled to it.
    c10::SymInt pooled_height, // Height of the pooled feature map.
    c10::SymInt pooled_width, // Width of the pooled feature map.
    int64_t sampling_ratio, // Number of points sampled in each bin.
    bool aligned) { // Flag for the pixel shift along each axis.
  C10_LOG_API_USAGE_ONCE("torchvision.csrc.ops.roi_align.roi_align");
  // Resolve the registered schema once and cache the typed handle for reuse.
  static auto dispatch_op = c10::Dispatcher::singleton()
                                .findSchemaOrThrow("torchvision::roi_align", "")
                                .typed<decltype(roi_align_symint)>();
  return dispatch_op.call(
      input,
      rois,
      spatial_scale,
      pooled_height,
      pooled_width,
      sampling_ratio,
      aligned);
}

namespace detail {

at::Tensor _roi_align_backward(
Expand Down Expand Up @@ -64,13 +89,43 @@ at::Tensor _roi_align_backward(
aligned);
}

// SymInt-typed entry point for torchvision::_roi_align_backward. Forwards all
// arguments to the dispatcher so backend-specific backward kernels are chosen
// by dispatch key.
at::Tensor _roi_align_backward_symint(
    const at::Tensor& grad,
    const at::Tensor& rois,
    double spatial_scale,
    c10::SymInt pooled_height,
    c10::SymInt pooled_width,
    c10::SymInt batch_size,
    c10::SymInt channels,
    c10::SymInt height,
    c10::SymInt width,
    int64_t sampling_ratio,
    bool aligned) {
  // Schema lookup happens once; the typed handle is cached across calls.
  static auto backward_op =
      c10::Dispatcher::singleton()
          .findSchemaOrThrow("torchvision::_roi_align_backward", "")
          .typed<decltype(_roi_align_backward_symint)>();
  return backward_op.call(
      grad,
      rois,
      spatial_scale,
      pooled_height,
      pooled_width,
      batch_size,
      channels,
      height,
      width,
      sampling_ratio,
      aligned);
}

} // namespace detail

TORCH_LIBRARY_FRAGMENT(torchvision, m) {
m.def(TORCH_SELECTIVE_SCHEMA(
"torchvision::roi_align(Tensor input, Tensor rois, float spatial_scale, int pooled_height, int pooled_width, int sampling_ratio, bool aligned) -> Tensor"));
"torchvision::roi_align(Tensor input, Tensor rois, float spatial_scale, SymInt pooled_height, SymInt pooled_width, int sampling_ratio, bool aligned) -> Tensor"));
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since we're now only registering the SymInt signature, do we still need to keep the pre-existing roi_align definitions/declarations that are using int64_t?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In core, we had to keep the old signatures because the C++ API is public API, and the SymInt signature is not exactly interchangeable with the int signature (as it can affect what implicit conversions are specified). If your old signatures are not public API, we can remove them too, but I'm guessing they are public-ish? In any case, this is the most conservative change.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm open to other strategies, as we will have to do this for every function we SymInt'ify which is going to be a bit of a pain.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Got it, thanks.

I'm guessing they are public-ish

Yeah... They're in the "we want them to be private but we don't know who's using them in the wild" category.

Let's keep them in for now and perhaps reconsider if this becomes too much of a mess.

m.def(TORCH_SELECTIVE_SCHEMA(
"torchvision::_roi_align_backward(Tensor grad, Tensor rois, float spatial_scale, int pooled_height, int pooled_width, int batch_size, int channels, int height, int width, int sampling_ratio, bool aligned) -> Tensor"));
"torchvision::_roi_align_backward(Tensor grad, Tensor rois, float spatial_scale, SymInt pooled_height, SymInt pooled_width, SymInt batch_size, SymInt channels, SymInt height, SymInt width, int sampling_ratio, bool aligned) -> Tensor"));
}

} // namespace ops
Expand Down
22 changes: 22 additions & 0 deletions torchvision/csrc/ops/roi_align.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,15 @@ VISION_API at::Tensor roi_align(
int64_t sampling_ratio,
bool aligned);

// SymInt-typed overload of roi_align: identical contract, but pooled_height /
// pooled_width are c10::SymInt so callers can pass symbolic shapes.
VISION_API at::Tensor roi_align_symint(
    const at::Tensor& input,
    const at::Tensor& rois,
    double spatial_scale,
    c10::SymInt pooled_height,
    c10::SymInt pooled_width,
    int64_t sampling_ratio,
    bool aligned);

namespace detail {

at::Tensor _roi_align_backward(
Expand All @@ -30,6 +39,19 @@ at::Tensor _roi_align_backward(
int64_t sampling_ratio,
bool aligned);

// SymInt-typed overload of _roi_align_backward. Returns the gradient w.r.t.
// the forward input, shaped (batch_size, channels, height, width).
at::Tensor _roi_align_backward_symint(
    const at::Tensor& grad,
    const at::Tensor& rois,
    double spatial_scale,
    c10::SymInt pooled_height,
    c10::SymInt pooled_width,
    c10::SymInt batch_size,
    c10::SymInt channels,
    c10::SymInt height,
    c10::SymInt width,
    int64_t sampling_ratio,
    bool aligned);

} // namespace detail

} // namespace ops
Expand Down