From 32e7abf723d22493076d9fc4d31149af3ee25707 Mon Sep 17 00:00:00 2001
From: Francisco Massa <fmassa@fb.com>
Date: Mon, 14 Dec 2020 16:50:53 -0800
Subject: [PATCH] Move autograd implementations to separate files. (#3154)
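
This moves the autograd implementations of deform_conv2d, ps_roi_align,
ps_roi_pool, roi_align and roi_pool out of the op definition files and into
dedicated kernels under torchvision/csrc/ops/autograd/, mirroring the
existing cpu/ and cuda/ layout. The backward entry points are now declared
inside a detail namespace in each op header.

Dispatch is unchanged: each public function still resolves its schema and
calls through the dispatcher, which routes Autograd-key invocations to the
relocated kernels. A minimal sketch of that (pre-existing) call path, as in
the roi_align wrapper:

    static auto op = c10::Dispatcher::singleton()
                         .findSchemaOrThrow("torchvision::roi_align", "")
                         .typed<decltype(roi_align)>();
    return op.call(
        input, rois, spatial_scale, pooled_height, pooled_width,
        sampling_ratio, aligned);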

Reviewed By: datumbox

Differential Revision: D25531038

fbshipit-source-id: 481434d15d0709417b3c36ff2b13e10a7994dff2
---
 CMakeLists.txt                                |   3 +-
 setup.py                                      |   3 +-
 .../ops/autograd/deform_conv2d_kernel.cpp     | 262 ++++++++++++++++++
 .../csrc/ops/autograd/ps_roi_align_kernel.cpp | 162 +++++++++++
 .../csrc/ops/autograd/ps_roi_pool_kernel.cpp  | 147 ++++++++++
 .../csrc/ops/autograd/roi_align_kernel.cpp    | 162 +++++++++++
 .../csrc/ops/autograd/roi_pool_kernel.cpp     | 147 ++++++++++
 torchvision/csrc/ops/deform_conv2d.cpp        | 257 +----------------
 torchvision/csrc/ops/deform_conv2d.h          |  23 ++
 torchvision/csrc/ops/nms.cpp                  |   1 -
 torchvision/csrc/ops/ps_roi_align.cpp         | 157 +----------
 torchvision/csrc/ops/ps_roi_align.h           |  17 ++
 torchvision/csrc/ops/ps_roi_pool.cpp          | 142 +---------
 torchvision/csrc/ops/ps_roi_pool.h            |  16 ++
 torchvision/csrc/ops/roi_align.cpp            | 157 +----------
 torchvision/csrc/ops/roi_align.h              |  17 ++
 torchvision/csrc/ops/roi_pool.cpp             | 142 +---------
 torchvision/csrc/ops/roi_pool.h               |  16 ++
 18 files changed, 993 insertions(+), 838 deletions(-)
 create mode 100644 torchvision/csrc/ops/autograd/deform_conv2d_kernel.cpp
 create mode 100644 torchvision/csrc/ops/autograd/ps_roi_align_kernel.cpp
 create mode 100644 torchvision/csrc/ops/autograd/ps_roi_pool_kernel.cpp
 create mode 100644 torchvision/csrc/ops/autograd/roi_align_kernel.cpp
 create mode 100644 torchvision/csrc/ops/autograd/roi_pool_kernel.cpp

diff --git a/CMakeLists.txt b/CMakeLists.txt
index 329feb6c6d3..e10a8a3d161 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -52,7 +52,8 @@ include(GNUInstallDirs)
 include(CMakePackageConfigHelpers)
 
 set(TVCPP torchvision/csrc)
-list(APPEND ALLOW_LISTED ${TVCPP} ${TVCPP}/io/image ${TVCPP}/io/image/cpu ${TVCPP}/models ${TVCPP}/ops ${TVCPP}/ops/cpu)
+list(APPEND ALLOW_LISTED ${TVCPP} ${TVCPP}/io/image ${TVCPP}/io/image/cpu ${TVCPP}/models ${TVCPP}/ops
+  ${TVCPP}/ops/autograd ${TVCPP}/ops/cpu)
 if(WITH_CUDA)
     list(APPEND ALLOW_LISTED ${TVCPP}/ops/cuda ${TVCPP}/ops/autocast)
 endif()
diff --git a/setup.py b/setup.py
index 23cc53268c4..da154ec14b2 100644
--- a/setup.py
+++ b/setup.py
@@ -136,7 +136,8 @@ def get_extensions():
 
     main_file = glob.glob(os.path.join(extensions_dir, '*.cpp')) + glob.glob(os.path.join(extensions_dir, 'ops',
                                                                                           '*.cpp'))
-    source_cpu = glob.glob(os.path.join(extensions_dir, 'ops', 'cpu', '*.cpp'))
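+    # The autograd kernels are backend-agnostic, so they are compiled
+    # unconditionally alongside the CPU sources.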
+    source_cpu = glob.glob(os.path.join(extensions_dir, 'ops', 'autograd', '*.cpp')) + glob.glob(
+        os.path.join(extensions_dir, 'ops', 'cpu', '*.cpp'))
 
     is_rocm_pytorch = False
     if torch.__version__ >= '1.5':
diff --git a/torchvision/csrc/ops/autograd/deform_conv2d_kernel.cpp b/torchvision/csrc/ops/autograd/deform_conv2d_kernel.cpp
new file mode 100644
index 00000000000..54d1512f3e8
--- /dev/null
+++ b/torchvision/csrc/ops/autograd/deform_conv2d_kernel.cpp
@@ -0,0 +1,262 @@
+#include "../deform_conv2d.h"
+
+#include <torch/autograd.h>
+#include <torch/types.h>
+
+namespace vision {
+namespace ops {
+
+namespace {
+
+class DeformConv2dFunction
+    : public torch::autograd::Function<DeformConv2dFunction> {
+ public:
+  static torch::autograd::variable_list forward(
+      torch::autograd::AutogradContext* ctx,
+      const torch::autograd::Variable& input,
+      const torch::autograd::Variable& weight,
+      const torch::autograd::Variable& offset,
+      const torch::autograd::Variable& mask,
+      const torch::autograd::Variable& bias,
+      int64_t stride_h,
+      int64_t stride_w,
+      int64_t pad_h,
+      int64_t pad_w,
+      int64_t dilation_h,
+      int64_t dilation_w,
+      int64_t groups,
+      int64_t offset_groups,
+      bool use_mask) {
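+    // Temporarily exclude autograd handling so the nested deform_conv2d
+    // call redispatches to the backend (CPU/CUDA) kernel instead of
+    // re-entering this function.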
+    at::AutoNonVariableTypeMode g;
+    auto output = deform_conv2d(
+        input,
+        weight,
+        offset,
+        mask,
+        bias,
+        stride_h,
+        stride_w,
+        pad_h,
+        pad_w,
+        dilation_h,
+        dilation_w,
+        groups,
+        offset_groups,
+        use_mask);
+
+    ctx->save_for_backward({input, weight, offset, mask, bias});
+    ctx->saved_data["stride_h"] = stride_h;
+    ctx->saved_data["stride_w"] = stride_w;
+    ctx->saved_data["pad_h"] = pad_h;
+    ctx->saved_data["pad_w"] = pad_w;
+    ctx->saved_data["dilation_h"] = dilation_h;
+    ctx->saved_data["dilation_w"] = dilation_w;
+    ctx->saved_data["groups"] = groups;
+    ctx->saved_data["offset_groups"] = offset_groups;
+    ctx->saved_data["use_mask"] = use_mask;
+
+    return {
+        output,
+    };
+  }
+
+  static torch::autograd::variable_list backward(
+      torch::autograd::AutogradContext* ctx,
+      const torch::autograd::variable_list& grad_output) {
+    auto saved = ctx->get_saved_variables();
+    auto input = saved[0];
+    auto weight = saved[1];
+    auto offset = saved[2];
+    auto mask = saved[3];
+    auto bias = saved[4];
+
+    auto stride_h = ctx->saved_data["stride_h"].toInt();
+    auto stride_w = ctx->saved_data["stride_w"].toInt();
+    auto pad_h = ctx->saved_data["pad_h"].toInt();
+    auto pad_w = ctx->saved_data["pad_w"].toInt();
+    auto dilation_h = ctx->saved_data["dilation_h"].toInt();
+    auto dilation_w = ctx->saved_data["dilation_w"].toInt();
+    auto groups = ctx->saved_data["groups"].toInt();
+    auto offset_groups = ctx->saved_data["offset_groups"].toInt();
+    auto use_mask = ctx->saved_data["use_mask"].toBool();
+
+    auto grads = detail::_deform_conv2d_backward(
+        grad_output[0],
+        input,
+        weight,
+        offset,
+        mask,
+        bias,
+        stride_h,
+        stride_w,
+        pad_h,
+        pad_w,
+        dilation_h,
+        dilation_w,
+        groups,
+        offset_groups,
+        use_mask);
+    auto grad_input = std::get<0>(grads);
+    auto grad_weight = std::get<1>(grads);
+    auto grad_offset = std::get<2>(grads);
+    auto grad_mask = std::get<3>(grads);
+    auto grad_bias = std::get<4>(grads);
+
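+    // backward() must return one entry per forward() input: gradients for
+    // the five tensor arguments, then an undefined Variable for each of the
+    // nine non-tensor arguments.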
+    return {
+        grad_input,
+        grad_weight,
+        grad_offset,
+        grad_mask,
+        grad_bias,
+        torch::autograd::Variable(),
+        torch::autograd::Variable(),
+        torch::autograd::Variable(),
+        torch::autograd::Variable(),
+        torch::autograd::Variable(),
+        torch::autograd::Variable(),
+        torch::autograd::Variable(),
+        torch::autograd::Variable(),
+        torch::autograd::Variable(),
+    };
+  }
+};
+
+// TODO: There should be an easier way to do this
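+// The backward op is itself registered with the dispatcher, so it also
+// needs an Autograd kernel: this one just redispatches to the backend
+// kernel and rejects double backward.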
+class DeformConv2dBackwardFunction
+    : public torch::autograd::Function<DeformConv2dBackwardFunction> {
+ public:
+  static torch::autograd::variable_list forward(
+      torch::autograd::AutogradContext* ctx,
+      const torch::autograd::Variable& grad,
+      const torch::autograd::Variable& input,
+      const torch::autograd::Variable& weight,
+      const torch::autograd::Variable& offset,
+      const torch::autograd::Variable& mask,
+      const torch::autograd::Variable& bias,
+      int64_t stride_h,
+      int64_t stride_w,
+      int64_t pad_h,
+      int64_t pad_w,
+      int64_t dilation_h,
+      int64_t dilation_w,
+      int64_t groups,
+      int64_t offset_groups,
+      bool use_mask) {
+    at::AutoNonVariableTypeMode g;
+    auto result = detail::_deform_conv2d_backward(
+        grad,
+        input,
+        weight,
+        offset,
+        mask,
+        bias,
+        stride_h,
+        stride_w,
+        pad_h,
+        pad_w,
+        dilation_h,
+        dilation_w,
+        groups,
+        offset_groups,
+        use_mask);
+
+    auto grad_input = std::get<0>(result);
+    auto grad_weight = std::get<1>(result);
+    auto grad_offset = std::get<2>(result);
+    auto grad_mask = std::get<3>(result);
+    auto grad_bias = std::get<4>(result);
+
+    return {
+        grad_input,
+        grad_weight,
+        grad_offset,
+        grad_mask,
+        grad_bias,
+    };
+  }
+
+  static torch::autograd::variable_list backward(
+      torch::autograd::AutogradContext* ctx,
+      const torch::autograd::variable_list& grad_output) {
+    TORCH_CHECK(0, "double backwards on deform_conv2d not supported");
+  }
+};
+
+at::Tensor deform_conv2d_autograd(
+    const at::Tensor& input,
+    const at::Tensor& weight,
+    const at::Tensor& offset,
+    const at::Tensor& mask,
+    const at::Tensor& bias,
+    int64_t stride_h,
+    int64_t stride_w,
+    int64_t pad_h,
+    int64_t pad_w,
+    int64_t dilation_h,
+    int64_t dilation_w,
+    int64_t groups,
+    int64_t offset_groups,
+    bool use_mask) {
+  return DeformConv2dFunction::apply(
+      input,
+      weight,
+      offset,
+      mask,
+      bias,
+      stride_h,
+      stride_w,
+      pad_h,
+      pad_w,
+      dilation_h,
+      dilation_w,
+      groups,
+      offset_groups,
+      use_mask)[0];
+}
+
+std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor>
+deform_conv2d_backward_autograd(
+    const at::Tensor& grad,
+    const at::Tensor& input,
+    const at::Tensor& weight,
+    const at::Tensor& offset,
+    const at::Tensor& mask,
+    const at::Tensor& bias,
+    int64_t stride_h,
+    int64_t stride_w,
+    int64_t pad_h,
+    int64_t pad_w,
+    int64_t dilation_h,
+    int64_t dilation_w,
+    int64_t groups,
+    int64_t offset_groups,
+    bool use_mask) {
+  auto result = DeformConv2dBackwardFunction::apply(
+      grad,
+      input,
+      weight,
+      offset,
+      mask,
+      bias,
+      stride_h,
+      stride_w,
+      pad_h,
+      pad_w,
+      dilation_h,
+      dilation_w,
+      groups,
+      offset_groups,
+      use_mask);
+
+  return std::make_tuple(result[0], result[1], result[2], result[3], result[4]);
+}
+
+} // namespace
+
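+// Bind the wrappers to the Autograd dispatch key; the op schemas themselves
+// are defined in ops/deform_conv2d.cpp.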
+TORCH_LIBRARY_IMPL(torchvision, Autograd, m) {
+  m.impl("deform_conv2d", deform_conv2d_autograd);
+  m.impl("_deform_conv2d_backward", deform_conv2d_backward_autograd);
+}
+
+} // namespace ops
+} // namespace vision
diff --git a/torchvision/csrc/ops/autograd/ps_roi_align_kernel.cpp b/torchvision/csrc/ops/autograd/ps_roi_align_kernel.cpp
new file mode 100644
index 00000000000..7fe215112d1
--- /dev/null
+++ b/torchvision/csrc/ops/autograd/ps_roi_align_kernel.cpp
@@ -0,0 +1,162 @@
+#include "../ps_roi_align.h"
+
+#include <torch/autograd.h>
+#include <torch/types.h>
+
+namespace vision {
+namespace ops {
+
+namespace {
+
+class PSROIAlignFunction
+    : public torch::autograd::Function<PSROIAlignFunction> {
+ public:
+  static torch::autograd::variable_list forward(
+      torch::autograd::AutogradContext* ctx,
+      const torch::autograd::Variable& input,
+      const torch::autograd::Variable& rois,
+      double spatial_scale,
+      int64_t pooled_height,
+      int64_t pooled_width,
+      int64_t sampling_ratio) {
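+    // Non-tensor arguments and the input shape go into saved_data; the
+    // tensors needed by backward are stored via save_for_backward below.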
+    ctx->saved_data["spatial_scale"] = spatial_scale;
+    ctx->saved_data["pooled_height"] = pooled_height;
+    ctx->saved_data["pooled_width"] = pooled_width;
+    ctx->saved_data["sampling_ratio"] = sampling_ratio;
+    ctx->saved_data["input_shape"] = input.sizes();
+    at::AutoNonVariableTypeMode g;
+    auto result = ps_roi_align(
+        input,
+        rois,
+        spatial_scale,
+        pooled_height,
+        pooled_width,
+        sampling_ratio);
+
+    auto output = std::get<0>(result);
+    auto channel_mapping = std::get<1>(result);
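+    // channel_mapping is an integer index tensor: return it as a second
+    // output, but mark it non-differentiable.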
+    ctx->save_for_backward({rois, channel_mapping});
+    ctx->mark_non_differentiable({channel_mapping});
+
+    return {output, channel_mapping};
+  }
+
+  static torch::autograd::variable_list backward(
+      torch::autograd::AutogradContext* ctx,
+      const torch::autograd::variable_list& grad_output) {
+    // Use data saved in forward
+    auto saved = ctx->get_saved_variables();
+    auto rois = saved[0];
+    auto channel_mapping = saved[1];
+    auto input_shape = ctx->saved_data["input_shape"].toIntList();
+    auto grad_in = detail::_ps_roi_align_backward(
+        grad_output[0],
+        rois,
+        channel_mapping,
+        ctx->saved_data["spatial_scale"].toDouble(),
+        ctx->saved_data["pooled_height"].toInt(),
+        ctx->saved_data["pooled_width"].toInt(),
+        ctx->saved_data["sampling_ratio"].toInt(),
+        input_shape[0],
+        input_shape[1],
+        input_shape[2],
+        input_shape[3]);
+
+    return {grad_in,
+            torch::autograd::Variable(),
+            torch::autograd::Variable(),
+            torch::autograd::Variable(),
+            torch::autograd::Variable(),
+            torch::autograd::Variable()};
+  }
+};
+
+// TODO: There should be an easier way to do this
+class PSROIAlignBackwardFunction
+    : public torch::autograd::Function<PSROIAlignBackwardFunction> {
+ public:
+  static torch::autograd::variable_list forward(
+      torch::autograd::AutogradContext* ctx,
+      const torch::autograd::Variable& grad,
+      const torch::autograd::Variable& rois,
+      const torch::autograd::Variable& channel_mapping,
+      double spatial_scale,
+      int64_t pooled_height,
+      int64_t pooled_width,
+      int64_t sampling_ratio,
+      int64_t batch_size,
+      int64_t channels,
+      int64_t height,
+      int64_t width) {
+    at::AutoNonVariableTypeMode g;
+    auto grad_in = detail::_ps_roi_align_backward(
+        grad,
+        rois,
+        channel_mapping,
+        spatial_scale,
+        pooled_height,
+        pooled_width,
+        sampling_ratio,
+        batch_size,
+        channels,
+        height,
+        width);
+
+    return {grad_in};
+  }
+
+  static torch::autograd::variable_list backward(
+      torch::autograd::AutogradContext* ctx,
+      const torch::autograd::variable_list& grad_output) {
+    TORCH_CHECK(0, "double backwards on ps_roi_align not supported");
+  }
+};
+
+std::tuple<at::Tensor, at::Tensor> ps_roi_align_autograd(
+    const at::Tensor& input,
+    const at::Tensor& rois,
+    double spatial_scale,
+    int64_t pooled_height,
+    int64_t pooled_width,
+    int64_t sampling_ratio) {
+  auto result = PSROIAlignFunction::apply(
+      input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio);
+
+  return std::make_tuple(result[0], result[1]);
+}
+
+at::Tensor ps_roi_align_backward_autograd(
+    const at::Tensor& grad,
+    const at::Tensor& rois,
+    const at::Tensor& channel_mapping,
+    double spatial_scale,
+    int64_t pooled_height,
+    int64_t pooled_width,
+    int64_t sampling_ratio,
+    int64_t batch_size,
+    int64_t channels,
+    int64_t height,
+    int64_t width) {
+  return PSROIAlignBackwardFunction::apply(
+      grad,
+      rois,
+      channel_mapping,
+      spatial_scale,
+      pooled_height,
+      pooled_width,
+      sampling_ratio,
+      batch_size,
+      channels,
+      height,
+      width)[0];
+}
+
+} // namespace
+
+TORCH_LIBRARY_IMPL(torchvision, Autograd, m) {
+  m.impl("ps_roi_align", ps_roi_align_autograd);
+  m.impl("_ps_roi_align_backward", ps_roi_align_backward_autograd);
+}
+
+} // namespace ops
+} // namespace vision
diff --git a/torchvision/csrc/ops/autograd/ps_roi_pool_kernel.cpp b/torchvision/csrc/ops/autograd/ps_roi_pool_kernel.cpp
new file mode 100644
index 00000000000..89891e03c1b
--- /dev/null
+++ b/torchvision/csrc/ops/autograd/ps_roi_pool_kernel.cpp
@@ -0,0 +1,147 @@
+#include "../ps_roi_pool.h"
+
+#include <torch/autograd.h>
+#include <torch/types.h>
+
+namespace vision {
+namespace ops {
+
+namespace {
+
+class PSROIPoolFunction : public torch::autograd::Function<PSROIPoolFunction> {
+ public:
+  static torch::autograd::variable_list forward(
+      torch::autograd::AutogradContext* ctx,
+      const torch::autograd::Variable& input,
+      const torch::autograd::Variable& rois,
+      double spatial_scale,
+      int64_t pooled_height,
+      int64_t pooled_width) {
+    ctx->saved_data["spatial_scale"] = spatial_scale;
+    ctx->saved_data["pooled_height"] = pooled_height;
+    ctx->saved_data["pooled_width"] = pooled_width;
+    ctx->saved_data["input_shape"] = input.sizes();
+    at::AutoNonVariableTypeMode g;
+    auto result =
+        ps_roi_pool(input, rois, spatial_scale, pooled_height, pooled_width);
+
+    auto output = std::get<0>(result);
+    auto channel_mapping = std::get<1>(result);
+    ctx->save_for_backward({rois, channel_mapping});
+    ctx->mark_non_differentiable({channel_mapping});
+
+    return {output, channel_mapping};
+  }
+
+  static torch::autograd::variable_list backward(
+      torch::autograd::AutogradContext* ctx,
+      const torch::autograd::variable_list& grad_output) {
+    // Use data saved in forward
+    auto saved = ctx->get_saved_variables();
+    auto rois = saved[0];
+    auto channel_mapping = saved[1];
+    auto input_shape = ctx->saved_data["input_shape"].toIntList();
+    auto grad_in = detail::_ps_roi_pool_backward(
+        grad_output[0],
+        rois,
+        channel_mapping,
+        ctx->saved_data["spatial_scale"].toDouble(),
+        ctx->saved_data["pooled_height"].toInt(),
+        ctx->saved_data["pooled_width"].toInt(),
+        input_shape[0],
+        input_shape[1],
+        input_shape[2],
+        input_shape[3]);
+
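+    // One entry per forward() input: the gradient w.r.t. input, then
+    // undefined Variables for rois and the three pooling parameters.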
+    return {grad_in,
+            torch::autograd::Variable(),
+            torch::autograd::Variable(),
+            torch::autograd::Variable(),
+            torch::autograd::Variable()};
+  }
+};
+
+// TODO: There should be an easier way to do this
+class PSROIPoolBackwardFunction
+    : public torch::autograd::Function<PSROIPoolBackwardFunction> {
+ public:
+  static torch::autograd::variable_list forward(
+      torch::autograd::AutogradContext* ctx,
+      const torch::autograd::Variable& grad,
+      const torch::autograd::Variable& rois,
+      const torch::autograd::Variable& channel_mapping,
+      double spatial_scale,
+      int64_t pooled_height,
+      int64_t pooled_width,
+      int64_t batch_size,
+      int64_t channels,
+      int64_t height,
+      int64_t width) {
+    at::AutoNonVariableTypeMode g;
+    auto grad_in = detail::_ps_roi_pool_backward(
+        grad,
+        rois,
+        channel_mapping,
+        spatial_scale,
+        pooled_height,
+        pooled_width,
+        batch_size,
+        channels,
+        height,
+        width);
+
+    return {grad_in};
+  }
+
+  static torch::autograd::variable_list backward(
+      torch::autograd::AutogradContext* ctx,
+      const torch::autograd::variable_list& grad_output) {
+    TORCH_CHECK(0, "double backwards on ps_roi_pool not supported");
+  }
+};
+
+std::tuple<at::Tensor, at::Tensor> ps_roi_pool_autograd(
+    const at::Tensor& input,
+    const at::Tensor& rois,
+    double spatial_scale,
+    int64_t pooled_height,
+    int64_t pooled_width) {
+  auto result = PSROIPoolFunction::apply(
+      input, rois, spatial_scale, pooled_height, pooled_width);
+
+  return std::make_tuple(result[0], result[1]);
+}
+
+at::Tensor ps_roi_pool_backward_autograd(
+    const at::Tensor& grad,
+    const at::Tensor& rois,
+    const at::Tensor& channel_mapping,
+    double spatial_scale,
+    int64_t pooled_height,
+    int64_t pooled_width,
+    int64_t batch_size,
+    int64_t channels,
+    int64_t height,
+    int64_t width) {
+  return PSROIPoolBackwardFunction::apply(
+      grad,
+      rois,
+      channel_mapping,
+      spatial_scale,
+      pooled_height,
+      pooled_width,
+      batch_size,
+      channels,
+      height,
+      width)[0];
+}
+
+} // namespace
+
+TORCH_LIBRARY_IMPL(torchvision, Autograd, m) {
+  m.impl("ps_roi_pool", ps_roi_pool_autograd);
+  m.impl("_ps_roi_pool_backward", ps_roi_pool_backward_autograd);
+}
+
+} // namespace ops
+} // namespace vision
diff --git a/torchvision/csrc/ops/autograd/roi_align_kernel.cpp b/torchvision/csrc/ops/autograd/roi_align_kernel.cpp
new file mode 100644
index 00000000000..751ee0cd64f
--- /dev/null
+++ b/torchvision/csrc/ops/autograd/roi_align_kernel.cpp
@@ -0,0 +1,162 @@
+#include "../roi_align.h"
+
+#include <torch/autograd.h>
+#include <torch/types.h>
+
+namespace vision {
+namespace ops {
+
+namespace {
+
+class ROIAlignFunction : public torch::autograd::Function<ROIAlignFunction> {
+ public:
+  static torch::autograd::variable_list forward(
+      torch::autograd::AutogradContext* ctx,
+      const torch::autograd::Variable& input,
+      const torch::autograd::Variable& rois,
+      double spatial_scale,
+      int64_t pooled_height,
+      int64_t pooled_width,
+      int64_t sampling_ratio,
+      bool aligned) {
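+    // Backward only needs rois and the input shape, so the input tensor
+    // itself is never saved.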
+    ctx->saved_data["spatial_scale"] = spatial_scale;
+    ctx->saved_data["pooled_height"] = pooled_height;
+    ctx->saved_data["pooled_width"] = pooled_width;
+    ctx->saved_data["sampling_ratio"] = sampling_ratio;
+    ctx->saved_data["aligned"] = aligned;
+    ctx->saved_data["input_shape"] = input.sizes();
+    ctx->save_for_backward({rois});
+    at::AutoNonVariableTypeMode g;
+    auto result = roi_align(
+        input,
+        rois,
+        spatial_scale,
+        pooled_height,
+        pooled_width,
+        sampling_ratio,
+        aligned);
+    return {result};
+  }
+
+  static torch::autograd::variable_list backward(
+      torch::autograd::AutogradContext* ctx,
+      const torch::autograd::variable_list& grad_output) {
+    // Use data saved in forward
+    auto saved = ctx->get_saved_variables();
+    auto rois = saved[0];
+    auto input_shape = ctx->saved_data["input_shape"].toIntList();
+    auto grad_in = detail::_roi_align_backward(
+        grad_output[0],
+        rois,
+        ctx->saved_data["spatial_scale"].toDouble(),
+        ctx->saved_data["pooled_height"].toInt(),
+        ctx->saved_data["pooled_width"].toInt(),
+        input_shape[0],
+        input_shape[1],
+        input_shape[2],
+        input_shape[3],
+        ctx->saved_data["sampling_ratio"].toInt(),
+        ctx->saved_data["aligned"].toBool());
+    return {grad_in,
+            torch::autograd::Variable(),
+            torch::autograd::Variable(),
+            torch::autograd::Variable(),
+            torch::autograd::Variable(),
+            torch::autograd::Variable(),
+            torch::autograd::Variable()};
+  }
+};
+
+// TODO: There should be an easier way to do this
+class ROIAlignBackwardFunction
+    : public torch::autograd::Function<ROIAlignBackwardFunction> {
+ public:
+  static torch::autograd::variable_list forward(
+      torch::autograd::AutogradContext* ctx,
+      const torch::autograd::Variable& grad,
+      const torch::autograd::Variable& rois,
+      double spatial_scale,
+      int64_t pooled_height,
+      int64_t pooled_width,
+      int64_t batch_size,
+      int64_t channels,
+      int64_t height,
+      int64_t width,
+      int64_t sampling_ratio,
+      bool aligned) {
+    at::AutoNonVariableTypeMode g;
+    auto result = detail::_roi_align_backward(
+        grad,
+        rois,
+        spatial_scale,
+        pooled_height,
+        pooled_width,
+        batch_size,
+        channels,
+        height,
+        width,
+        sampling_ratio,
+        aligned);
+    return {result};
+  }
+
+  static torch::autograd::variable_list backward(
+      torch::autograd::AutogradContext* ctx,
+      const torch::autograd::variable_list& grad_output) {
+    TORCH_CHECK(0, "double backwards on roi_align not supported");
+  }
+};
+
+at::Tensor roi_align_autograd(
+    const at::Tensor& input,
+    const at::Tensor& rois,
+    double spatial_scale,
+    int64_t pooled_height,
+    int64_t pooled_width,
+    int64_t sampling_ratio,
+    bool aligned) {
+  return ROIAlignFunction::apply(
+      input,
+      rois,
+      spatial_scale,
+      pooled_height,
+      pooled_width,
+      sampling_ratio,
+      aligned)[0];
+}
+
+at::Tensor roi_align_backward_autograd(
+    const at::Tensor& grad,
+    const at::Tensor& rois,
+    double spatial_scale,
+    int64_t pooled_height,
+    int64_t pooled_width,
+    int64_t batch_size,
+    int64_t channels,
+    int64_t height,
+    int64_t width,
+    int64_t sampling_ratio,
+    bool aligned) {
+  return ROIAlignBackwardFunction::apply(
+      grad,
+      rois,
+      spatial_scale,
+      pooled_height,
+      pooled_width,
+      batch_size,
+      channels,
+      height,
+      width,
+      sampling_ratio,
+      aligned)[0];
+}
+
+} // namespace
+
+TORCH_LIBRARY_IMPL(torchvision, Autograd, m) {
+  m.impl("roi_align", roi_align_autograd);
+  m.impl("_roi_align_backward", roi_align_backward_autograd);
+}
+
+} // namespace ops
+} // namespace vision
diff --git a/torchvision/csrc/ops/autograd/roi_pool_kernel.cpp b/torchvision/csrc/ops/autograd/roi_pool_kernel.cpp
new file mode 100644
index 00000000000..e6f9b23ddfb
--- /dev/null
+++ b/torchvision/csrc/ops/autograd/roi_pool_kernel.cpp
@@ -0,0 +1,147 @@
+#include "../roi_pool.h"
+
+#include <torch/autograd.h>
+#include <torch/types.h>
+
+namespace vision {
+namespace ops {
+
+namespace {
+
+class ROIPoolFunction : public torch::autograd::Function<ROIPoolFunction> {
+ public:
+  static torch::autograd::variable_list forward(
+      torch::autograd::AutogradContext* ctx,
+      const torch::autograd::Variable& input,
+      const torch::autograd::Variable& rois,
+      double spatial_scale,
+      int64_t pooled_height,
+      int64_t pooled_width) {
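+    // Stash the pooling parameters and input shape; rois and the argmax
+    // indices from the forward kernel are saved via save_for_backward below.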
+    ctx->saved_data["spatial_scale"] = spatial_scale;
+    ctx->saved_data["pooled_height"] = pooled_height;
+    ctx->saved_data["pooled_width"] = pooled_width;
+    ctx->saved_data["input_shape"] = input.sizes();
+    at::AutoNonVariableTypeMode g;
+    auto result =
+        roi_pool(input, rois, spatial_scale, pooled_height, pooled_width);
+
+    auto output = std::get<0>(result);
+    auto argmax = std::get<1>(result);
+    ctx->save_for_backward({rois, argmax});
+    ctx->mark_non_differentiable({argmax});
+
+    return {output, argmax};
+  }
+
+  static torch::autograd::variable_list backward(
+      torch::autograd::AutogradContext* ctx,
+      const torch::autograd::variable_list& grad_output) {
+    // Use data saved in forward
+    auto saved = ctx->get_saved_variables();
+    auto rois = saved[0];
+    auto argmax = saved[1];
+    auto input_shape = ctx->saved_data["input_shape"].toIntList();
+    auto grad_in = detail::_roi_pool_backward(
+        grad_output[0],
+        rois,
+        argmax,
+        ctx->saved_data["spatial_scale"].toDouble(),
+        ctx->saved_data["pooled_height"].toInt(),
+        ctx->saved_data["pooled_width"].toInt(),
+        input_shape[0],
+        input_shape[1],
+        input_shape[2],
+        input_shape[3]);
+
+    return {grad_in,
+            torch::autograd::Variable(),
+            torch::autograd::Variable(),
+            torch::autograd::Variable(),
+            torch::autograd::Variable()};
+  }
+};
+
+// TODO: There should be an easier way to do this
+class ROIPoolBackwardFunction
+    : public torch::autograd::Function<ROIPoolBackwardFunction> {
+ public:
+  static torch::autograd::variable_list forward(
+      torch::autograd::AutogradContext* ctx,
+      const torch::autograd::Variable& grad,
+      const torch::autograd::Variable& rois,
+      const torch::autograd::Variable& argmax,
+      double spatial_scale,
+      int64_t pooled_height,
+      int64_t pooled_width,
+      int64_t batch_size,
+      int64_t channels,
+      int64_t height,
+      int64_t width) {
+    at::AutoNonVariableTypeMode g;
+    auto grad_in = detail::_roi_pool_backward(
+        grad,
+        rois,
+        argmax,
+        spatial_scale,
+        pooled_height,
+        pooled_width,
+        batch_size,
+        channels,
+        height,
+        width);
+
+    return {grad_in};
+  }
+
+  static torch::autograd::variable_list backward(
+      torch::autograd::AutogradContext* ctx,
+      const torch::autograd::variable_list& grad_output) {
+    TORCH_CHECK(0, "double backwards on roi_pool not supported");
+  }
+};
+
+std::tuple<at::Tensor, at::Tensor> roi_pool_autograd(
+    const at::Tensor& input,
+    const at::Tensor& rois,
+    double spatial_scale,
+    int64_t pooled_height,
+    int64_t pooled_width) {
+  auto result = ROIPoolFunction::apply(
+      input, rois, spatial_scale, pooled_height, pooled_width);
+
+  return std::make_tuple(result[0], result[1]);
+}
+
+at::Tensor roi_pool_backward_autograd(
+    const at::Tensor& grad,
+    const at::Tensor& rois,
+    const at::Tensor& argmax,
+    double spatial_scale,
+    int64_t pooled_height,
+    int64_t pooled_width,
+    int64_t batch_size,
+    int64_t channels,
+    int64_t height,
+    int64_t width) {
+  return ROIPoolBackwardFunction::apply(
+      grad,
+      rois,
+      argmax,
+      spatial_scale,
+      pooled_height,
+      pooled_width,
+      batch_size,
+      channels,
+      height,
+      width)[0];
+}
+
+} // namespace
+
+TORCH_LIBRARY_IMPL(torchvision, Autograd, m) {
+  m.impl("roi_pool", roi_pool_autograd);
+  m.impl("_roi_pool_backward", roi_pool_backward_autograd);
+}
+
+} // namespace ops
+} // namespace vision
diff --git a/torchvision/csrc/ops/deform_conv2d.cpp b/torchvision/csrc/ops/deform_conv2d.cpp
index 4b8b2e97668..44870845bd4 100644
--- a/torchvision/csrc/ops/deform_conv2d.cpp
+++ b/torchvision/csrc/ops/deform_conv2d.cpp
@@ -1,6 +1,5 @@
 #include "deform_conv2d.h"
 
-#include <torch/autograd.h>
 #include <torch/types.h>
 
 namespace vision {
@@ -41,6 +40,8 @@ at::Tensor deform_conv2d(
       use_mask);
 }
 
+namespace detail {
+
 std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor>
 _deform_conv2d_backward(
     const at::Tensor& grad,
@@ -80,6 +81,8 @@ _deform_conv2d_backward(
       use_mask);
 }
 
+} // namespace detail
+
 TORCH_LIBRARY_FRAGMENT(torchvision, m) {
   m.def(
       "deform_conv2d(Tensor input, Tensor weight, Tensor offset, Tensor mask, Tensor bias, int stride_h, int stride_w, int pad_h, int pad_w, int dilation_h, int dilation_w, int groups, int offset_groups, bool use_mask) -> Tensor");
@@ -87,257 +90,5 @@ TORCH_LIBRARY_FRAGMENT(torchvision, m) {
       "_deform_conv2d_backward(Tensor grad, Tensor input, Tensor weight, Tensor offset, Tensor mask, Tensor bias, int stride_h, int stride_w, int pad_h, int pad_w, int dilation_h, int dilation_w, int groups, int offset_groups, bool use_mask) -> (Tensor, Tensor, Tensor, Tensor, Tensor)");
 }
 
-namespace {
-
-class DeformConv2dFunction
-    : public torch::autograd::Function<DeformConv2dFunction> {
- public:
-  static torch::autograd::variable_list forward(
-      torch::autograd::AutogradContext* ctx,
-      const torch::autograd::Variable& input,
-      const torch::autograd::Variable& weight,
-      const torch::autograd::Variable& offset,
-      const torch::autograd::Variable& mask,
-      const torch::autograd::Variable& bias,
-      int64_t stride_h,
-      int64_t stride_w,
-      int64_t pad_h,
-      int64_t pad_w,
-      int64_t dilation_h,
-      int64_t dilation_w,
-      int64_t groups,
-      int64_t offset_groups,
-      bool use_mask) {
-    at::AutoNonVariableTypeMode g;
-    auto output = deform_conv2d(
-        input,
-        weight,
-        offset,
-        mask,
-        bias,
-        stride_h,
-        stride_w,
-        pad_h,
-        pad_w,
-        dilation_h,
-        dilation_w,
-        groups,
-        offset_groups,
-        use_mask);
-
-    ctx->save_for_backward({input, weight, offset, mask, bias});
-    ctx->saved_data["stride_h"] = stride_h;
-    ctx->saved_data["stride_w"] = stride_w;
-    ctx->saved_data["pad_h"] = pad_h;
-    ctx->saved_data["pad_w"] = pad_w;
-    ctx->saved_data["dilation_h"] = dilation_h;
-    ctx->saved_data["dilation_w"] = dilation_w;
-    ctx->saved_data["groups"] = groups;
-    ctx->saved_data["offset_groups"] = offset_groups;
-    ctx->saved_data["use_mask"] = use_mask;
-
-    return {
-        output,
-    };
-  }
-
-  static torch::autograd::variable_list backward(
-      torch::autograd::AutogradContext* ctx,
-      const torch::autograd::variable_list& grad_output) {
-    auto saved = ctx->get_saved_variables();
-    auto input = saved[0];
-    auto weight = saved[1];
-    auto offset = saved[2];
-    auto mask = saved[3];
-    auto bias = saved[4];
-
-    auto stride_h = ctx->saved_data["stride_h"].toInt();
-    auto stride_w = ctx->saved_data["stride_w"].toInt();
-    auto pad_h = ctx->saved_data["pad_h"].toInt();
-    auto pad_w = ctx->saved_data["pad_w"].toInt();
-    auto dilation_h = ctx->saved_data["dilation_h"].toInt();
-    auto dilation_w = ctx->saved_data["dilation_w"].toInt();
-    auto groups = ctx->saved_data["groups"].toInt();
-    auto offset_groups = ctx->saved_data["offset_groups"].toInt();
-    auto use_mask = ctx->saved_data["use_mask"].toBool();
-
-    auto grads = _deform_conv2d_backward(
-        grad_output[0],
-        input,
-        weight,
-        offset,
-        mask,
-        bias,
-        stride_h,
-        stride_w,
-        pad_h,
-        pad_w,
-        dilation_h,
-        dilation_w,
-        groups,
-        offset_groups,
-        use_mask);
-    auto grad_input = std::get<0>(grads);
-    auto grad_weight = std::get<1>(grads);
-    auto grad_offset = std::get<2>(grads);
-    auto grad_mask = std::get<3>(grads);
-    auto grad_bias = std::get<4>(grads);
-
-    return {
-        grad_input,
-        grad_weight,
-        grad_offset,
-        grad_mask,
-        grad_bias,
-        torch::autograd::Variable(),
-        torch::autograd::Variable(),
-        torch::autograd::Variable(),
-        torch::autograd::Variable(),
-        torch::autograd::Variable(),
-        torch::autograd::Variable(),
-        torch::autograd::Variable(),
-        torch::autograd::Variable(),
-        torch::autograd::Variable(),
-    };
-  }
-};
-
-// TODO: There should be an easier way to do this
-class DeformConv2dBackwardFunction
-    : public torch::autograd::Function<DeformConv2dBackwardFunction> {
- public:
-  static torch::autograd::variable_list forward(
-      torch::autograd::AutogradContext* ctx,
-      const torch::autograd::Variable& grad,
-      const torch::autograd::Variable& input,
-      const torch::autograd::Variable& weight,
-      const torch::autograd::Variable& offset,
-      const torch::autograd::Variable& mask,
-      const torch::autograd::Variable& bias,
-      int64_t stride_h,
-      int64_t stride_w,
-      int64_t pad_h,
-      int64_t pad_w,
-      int64_t dilation_h,
-      int64_t dilation_w,
-      int64_t groups,
-      int64_t offset_groups,
-      bool use_mask) {
-    at::AutoNonVariableTypeMode g;
-    auto result = _deform_conv2d_backward(
-        grad,
-        input,
-        weight,
-        offset,
-        mask,
-        bias,
-        stride_h,
-        stride_w,
-        pad_h,
-        pad_w,
-        dilation_h,
-        dilation_w,
-        groups,
-        offset_groups,
-        use_mask);
-
-    auto grad_input = std::get<0>(result);
-    auto grad_weight = std::get<1>(result);
-    auto grad_offset = std::get<2>(result);
-    auto grad_mask = std::get<3>(result);
-    auto grad_bias = std::get<4>(result);
-
-    return {
-        grad_input,
-        grad_weight,
-        grad_offset,
-        grad_mask,
-        grad_bias,
-    };
-  }
-
-  static torch::autograd::variable_list backward(
-      torch::autograd::AutogradContext* ctx,
-      const torch::autograd::variable_list& grad_output) {
-    TORCH_CHECK(0, "double backwards on deform_conv2d not supported");
-  }
-};
-
-at::Tensor deform_conv2d_autograd(
-    const at::Tensor& input,
-    const at::Tensor& weight,
-    const at::Tensor& offset,
-    const at::Tensor& mask,
-    const at::Tensor& bias,
-    int64_t stride_h,
-    int64_t stride_w,
-    int64_t pad_h,
-    int64_t pad_w,
-    int64_t dilation_h,
-    int64_t dilation_w,
-    int64_t groups,
-    int64_t offset_groups,
-    bool use_mask) {
-  return DeformConv2dFunction::apply(
-      input,
-      weight,
-      offset,
-      mask,
-      bias,
-      stride_h,
-      stride_w,
-      pad_h,
-      pad_w,
-      dilation_h,
-      dilation_w,
-      groups,
-      offset_groups,
-      use_mask)[0];
-}
-
-std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor>
-deform_conv2d_backward_autograd(
-    const at::Tensor& grad,
-    const at::Tensor& input,
-    const at::Tensor& weight,
-    const at::Tensor& offset,
-    const at::Tensor& mask,
-    const at::Tensor& bias,
-    int64_t stride_h,
-    int64_t stride_w,
-    int64_t pad_h,
-    int64_t pad_w,
-    int64_t dilation_h,
-    int64_t dilation_w,
-    int64_t groups,
-    int64_t offset_groups,
-    bool use_mask) {
-  auto result = DeformConv2dBackwardFunction::apply(
-      grad,
-      input,
-      weight,
-      offset,
-      mask,
-      bias,
-      stride_h,
-      stride_w,
-      pad_h,
-      pad_w,
-      dilation_h,
-      dilation_w,
-      groups,
-      offset_groups,
-      use_mask);
-
-  return std::make_tuple(result[0], result[1], result[2], result[3], result[4]);
-}
-
-} // namespace
-
-TORCH_LIBRARY_IMPL(torchvision, Autograd, m) {
-  m.impl("deform_conv2d", deform_conv2d_autograd);
-  m.impl("_deform_conv2d_backward", deform_conv2d_backward_autograd);
-}
-
 } // namespace ops
 } // namespace vision
diff --git a/torchvision/csrc/ops/deform_conv2d.h b/torchvision/csrc/ops/deform_conv2d.h
index e94636fb280..a72d8ddde9c 100644
--- a/torchvision/csrc/ops/deform_conv2d.h
+++ b/torchvision/csrc/ops/deform_conv2d.h
@@ -22,5 +22,28 @@ VISION_API at::Tensor deform_conv2d(
     int64_t offset_groups,
     bool use_mask);
 
+namespace detail {
+
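+// Backward entry point used by the autograd kernel; not part of the public
+// C++ API.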
+VISION_API
+std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor>
+_deform_conv2d_backward(
+    const at::Tensor& grad,
+    const at::Tensor& input,
+    const at::Tensor& weight,
+    const at::Tensor& offset,
+    const at::Tensor& mask,
+    const at::Tensor& bias,
+    int64_t stride_h,
+    int64_t stride_w,
+    int64_t pad_h,
+    int64_t pad_w,
+    int64_t dilation_h,
+    int64_t dilation_w,
+    int64_t groups,
+    int64_t offset_groups,
+    bool use_mask);
+
+} // namespace detail
+
 } // namespace ops
 } // namespace vision
diff --git a/torchvision/csrc/ops/nms.cpp b/torchvision/csrc/ops/nms.cpp
index 3655a8d00df..8c2455f1142 100644
--- a/torchvision/csrc/ops/nms.cpp
+++ b/torchvision/csrc/ops/nms.cpp
@@ -1,6 +1,5 @@
 #include "nms.h"
 
-#include <torch/autograd.h>
 #include <torch/types.h>
 
 namespace vision {
diff --git a/torchvision/csrc/ops/ps_roi_align.cpp b/torchvision/csrc/ops/ps_roi_align.cpp
index 6092ee4cb68..53925e6f2ed 100644
--- a/torchvision/csrc/ops/ps_roi_align.cpp
+++ b/torchvision/csrc/ops/ps_roi_align.cpp
@@ -1,6 +1,5 @@
 #include "ps_roi_align.h"
 
-#include <torch/autograd.h>
 #include <torch/types.h>
 
 namespace vision {
@@ -20,6 +19,8 @@ std::tuple<at::Tensor, at::Tensor> ps_roi_align(
       input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio);
 }
 
+namespace detail {
+
 at::Tensor _ps_roi_align_backward(
     const at::Tensor& grad,
     const at::Tensor& rois,
@@ -50,6 +51,8 @@ at::Tensor _ps_roi_align_backward(
       width);
 }
 
+} // namespace detail
+
 TORCH_LIBRARY_FRAGMENT(torchvision, m) {
   m.def(
       "ps_roi_align(Tensor input, Tensor rois, float spatial_scale, int pooled_height, int pooled_width, int sampling_ratio) -> (Tensor, Tensor)");
@@ -57,157 +60,5 @@ TORCH_LIBRARY_FRAGMENT(torchvision, m) {
       "_ps_roi_align_backward(Tensor grad, Tensor rois, Tensor channel_mapping, float spatial_scale, int pooled_height, int pooled_width, int sampling_ratio, int batch_size, int channels, int height, int width) -> Tensor");
 }
 
-namespace {
-
-class PSROIAlignFunction
-    : public torch::autograd::Function<PSROIAlignFunction> {
- public:
-  static torch::autograd::variable_list forward(
-      torch::autograd::AutogradContext* ctx,
-      const torch::autograd::Variable& input,
-      const torch::autograd::Variable& rois,
-      double spatial_scale,
-      int64_t pooled_height,
-      int64_t pooled_width,
-      int64_t sampling_ratio) {
-    ctx->saved_data["spatial_scale"] = spatial_scale;
-    ctx->saved_data["pooled_height"] = pooled_height;
-    ctx->saved_data["pooled_width"] = pooled_width;
-    ctx->saved_data["sampling_ratio"] = sampling_ratio;
-    ctx->saved_data["input_shape"] = input.sizes();
-    at::AutoNonVariableTypeMode g;
-    auto result = ps_roi_align(
-        input,
-        rois,
-        spatial_scale,
-        pooled_height,
-        pooled_width,
-        sampling_ratio);
-
-    auto output = std::get<0>(result);
-    auto channel_mapping = std::get<1>(result);
-    ctx->save_for_backward({rois, channel_mapping});
-    ctx->mark_non_differentiable({channel_mapping});
-
-    return {output, channel_mapping};
-  }
-
-  static torch::autograd::variable_list backward(
-      torch::autograd::AutogradContext* ctx,
-      const torch::autograd::variable_list& grad_output) {
-    // Use data saved in forward
-    auto saved = ctx->get_saved_variables();
-    auto rois = saved[0];
-    auto channel_mapping = saved[1];
-    auto input_shape = ctx->saved_data["input_shape"].toIntList();
-    auto grad_in = _ps_roi_align_backward(
-        grad_output[0],
-        rois,
-        channel_mapping,
-        ctx->saved_data["spatial_scale"].toDouble(),
-        ctx->saved_data["pooled_height"].toInt(),
-        ctx->saved_data["pooled_width"].toInt(),
-        ctx->saved_data["sampling_ratio"].toInt(),
-        input_shape[0],
-        input_shape[1],
-        input_shape[2],
-        input_shape[3]);
-
-    return {grad_in,
-            torch::autograd::Variable(),
-            torch::autograd::Variable(),
-            torch::autograd::Variable(),
-            torch::autograd::Variable(),
-            torch::autograd::Variable()};
-  }
-};
-
-// TODO: There should be an easier way to do this
-class PSROIAlignBackwardFunction
-    : public torch::autograd::Function<PSROIAlignBackwardFunction> {
- public:
-  static torch::autograd::variable_list forward(
-      torch::autograd::AutogradContext* ctx,
-      const torch::autograd::Variable& grad,
-      const torch::autograd::Variable& rois,
-      const torch::autograd::Variable& channel_mapping,
-      double spatial_scale,
-      int64_t pooled_height,
-      int64_t pooled_width,
-      int64_t sampling_ratio,
-      int64_t batch_size,
-      int64_t channels,
-      int64_t height,
-      int64_t width) {
-    at::AutoNonVariableTypeMode g;
-    auto grad_in = _ps_roi_align_backward(
-        grad,
-        rois,
-        channel_mapping,
-        spatial_scale,
-        pooled_height,
-        pooled_width,
-        sampling_ratio,
-        batch_size,
-        channels,
-        height,
-        width);
-
-    return {grad_in};
-  }
-
-  static torch::autograd::variable_list backward(
-      torch::autograd::AutogradContext* ctx,
-      const torch::autograd::variable_list& grad_output) {
-    TORCH_CHECK(0, "double backwards on ps_roi_align not supported");
-  }
-};
-
-std::tuple<at::Tensor, at::Tensor> ps_roi_align_autograd(
-    const at::Tensor& input,
-    const at::Tensor& rois,
-    double spatial_scale,
-    int64_t pooled_height,
-    int64_t pooled_width,
-    int64_t sampling_ratio) {
-  auto result = PSROIAlignFunction::apply(
-      input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio);
-
-  return std::make_tuple(result[0], result[1]);
-}
-
-at::Tensor ps_roi_align_backward_autograd(
-    const at::Tensor& grad,
-    const at::Tensor& rois,
-    const at::Tensor& channel_mapping,
-    double spatial_scale,
-    int64_t pooled_height,
-    int64_t pooled_width,
-    int64_t sampling_ratio,
-    int64_t batch_size,
-    int64_t channels,
-    int64_t height,
-    int64_t width) {
-  return PSROIAlignBackwardFunction::apply(
-      grad,
-      rois,
-      channel_mapping,
-      spatial_scale,
-      pooled_height,
-      pooled_width,
-      sampling_ratio,
-      batch_size,
-      channels,
-      height,
-      width)[0];
-}
-
-} // namespace
-
-TORCH_LIBRARY_IMPL(torchvision, Autograd, m) {
-  m.impl("ps_roi_align", ps_roi_align_autograd);
-  m.impl("_ps_roi_align_backward", ps_roi_align_backward_autograd);
-}
-
 } // namespace ops
 } // namespace vision
diff --git a/torchvision/csrc/ops/ps_roi_align.h b/torchvision/csrc/ops/ps_roi_align.h
index cee53455f0c..907662d2dfc 100644
--- a/torchvision/csrc/ops/ps_roi_align.h
+++ b/torchvision/csrc/ops/ps_roi_align.h
@@ -14,5 +14,22 @@ VISION_API std::tuple<at::Tensor, at::Tensor> ps_roi_align(
     int64_t pooled_width,
     int64_t sampling_ratio);
 
+namespace detail {
+
+VISION_API at::Tensor _ps_roi_align_backward(
+    const at::Tensor& grad,
+    const at::Tensor& rois,
+    const at::Tensor& channel_mapping,
+    double spatial_scale,
+    int64_t pooled_height,
+    int64_t pooled_width,
+    int64_t sampling_ratio,
+    int64_t batch_size,
+    int64_t channels,
+    int64_t height,
+    int64_t width);
+
+} // namespace detail
+
 } // namespace ops
 } // namespace vision
diff --git a/torchvision/csrc/ops/ps_roi_pool.cpp b/torchvision/csrc/ops/ps_roi_pool.cpp
index 2c87d0c4fe8..1bc7df233d1 100644
--- a/torchvision/csrc/ops/ps_roi_pool.cpp
+++ b/torchvision/csrc/ops/ps_roi_pool.cpp
@@ -1,6 +1,5 @@
 #include "ps_roi_pool.h"
 
-#include <torch/autograd.h>
 #include <torch/types.h>
 
 namespace vision {
@@ -18,6 +17,8 @@ std::tuple<at::Tensor, at::Tensor> ps_roi_pool(
   return op.call(input, rois, spatial_scale, pooled_height, pooled_width);
 }
 
+namespace detail {
+
 at::Tensor _ps_roi_pool_backward(
     const at::Tensor& grad,
     const at::Tensor& rois,
@@ -46,6 +47,8 @@ at::Tensor _ps_roi_pool_backward(
       width);
 }
 
+} // namespace detail
+
 TORCH_LIBRARY_FRAGMENT(torchvision, m) {
   m.def(
       "ps_roi_pool(Tensor input, Tensor rois, float spatial_scale, int pooled_height, int pooled_width) -> (Tensor, Tensor)");
@@ -53,142 +56,5 @@ TORCH_LIBRARY_FRAGMENT(torchvision, m) {
       "_ps_roi_pool_backward(Tensor grad, Tensor rois, Tensor channel_mapping, float spatial_scale, int pooled_height, int pooled_width, int batch_size, int channels, int height, int width) -> Tensor");
 }
 
-namespace {
-
-class PSROIPoolFunction : public torch::autograd::Function<PSROIPoolFunction> {
- public:
-  static torch::autograd::variable_list forward(
-      torch::autograd::AutogradContext* ctx,
-      const torch::autograd::Variable& input,
-      const torch::autograd::Variable& rois,
-      double spatial_scale,
-      int64_t pooled_height,
-      int64_t pooled_width) {
-    ctx->saved_data["spatial_scale"] = spatial_scale;
-    ctx->saved_data["pooled_height"] = pooled_height;
-    ctx->saved_data["pooled_width"] = pooled_width;
-    ctx->saved_data["input_shape"] = input.sizes();
-    at::AutoNonVariableTypeMode g;
-    auto result =
-        ps_roi_pool(input, rois, spatial_scale, pooled_height, pooled_width);
-
-    auto output = std::get<0>(result);
-    auto channel_mapping = std::get<1>(result);
-    ctx->save_for_backward({rois, channel_mapping});
-    ctx->mark_non_differentiable({channel_mapping});
-
-    return {output, channel_mapping};
-  }
-
-  static torch::autograd::variable_list backward(
-      torch::autograd::AutogradContext* ctx,
-      const torch::autograd::variable_list& grad_output) {
-    // Use data saved in forward
-    auto saved = ctx->get_saved_variables();
-    auto rois = saved[0];
-    auto channel_mapping = saved[1];
-    auto input_shape = ctx->saved_data["input_shape"].toIntList();
-    auto grad_in = _ps_roi_pool_backward(
-        grad_output[0],
-        rois,
-        channel_mapping,
-        ctx->saved_data["spatial_scale"].toDouble(),
-        ctx->saved_data["pooled_height"].toInt(),
-        ctx->saved_data["pooled_width"].toInt(),
-        input_shape[0],
-        input_shape[1],
-        input_shape[2],
-        input_shape[3]);
-
-    return {grad_in,
-            torch::autograd::Variable(),
-            torch::autograd::Variable(),
-            torch::autograd::Variable(),
-            torch::autograd::Variable()};
-  }
-};
-
-// TODO: There should be an easier way to do this
-class PSROIPoolBackwardFunction
-    : public torch::autograd::Function<PSROIPoolBackwardFunction> {
- public:
-  static torch::autograd::variable_list forward(
-      torch::autograd::AutogradContext* ctx,
-      const torch::autograd::Variable& grad,
-      const torch::autograd::Variable& rois,
-      const torch::autograd::Variable& channel_mapping,
-      double spatial_scale,
-      int64_t pooled_height,
-      int64_t pooled_width,
-      int64_t batch_size,
-      int64_t channels,
-      int64_t height,
-      int64_t width) {
-    at::AutoNonVariableTypeMode g;
-    auto grad_in = _ps_roi_pool_backward(
-        grad,
-        rois,
-        channel_mapping,
-        spatial_scale,
-        pooled_height,
-        pooled_width,
-        batch_size,
-        channels,
-        height,
-        width);
-
-    return {grad_in};
-  }
-
-  static torch::autograd::variable_list backward(
-      torch::autograd::AutogradContext* ctx,
-      const torch::autograd::variable_list& grad_output) {
-    TORCH_CHECK(0, "double backwards on ps_roi_pool not supported");
-  }
-};
-
-std::tuple<at::Tensor, at::Tensor> ps_roi_pool_autograd(
-    const at::Tensor& input,
-    const at::Tensor& rois,
-    double spatial_scale,
-    int64_t pooled_height,
-    int64_t pooled_width) {
-  auto result = PSROIPoolFunction::apply(
-      input, rois, spatial_scale, pooled_height, pooled_width);
-
-  return std::make_tuple(result[0], result[1]);
-}
-
-at::Tensor ps_roi_pool_backward_autograd(
-    const at::Tensor& grad,
-    const at::Tensor& rois,
-    const at::Tensor& channel_mapping,
-    double spatial_scale,
-    int64_t pooled_height,
-    int64_t pooled_width,
-    int64_t batch_size,
-    int64_t channels,
-    int64_t height,
-    int64_t width) {
-  return PSROIPoolBackwardFunction::apply(
-      grad,
-      rois,
-      channel_mapping,
-      spatial_scale,
-      pooled_height,
-      pooled_width,
-      batch_size,
-      channels,
-      height,
-      width)[0];
-}
-
-} // namespace
-
-TORCH_LIBRARY_IMPL(torchvision, Autograd, m) {
-  m.impl("ps_roi_pool", ps_roi_pool_autograd);
-  m.impl("_ps_roi_pool_backward", ps_roi_pool_backward_autograd);
-}
-
 } // namespace ops
 } // namespace vision
diff --git a/torchvision/csrc/ops/ps_roi_pool.h b/torchvision/csrc/ops/ps_roi_pool.h
index 2c2bab314d0..7de748ce0ff 100644
--- a/torchvision/csrc/ops/ps_roi_pool.h
+++ b/torchvision/csrc/ops/ps_roi_pool.h
@@ -13,5 +13,21 @@ VISION_API std::tuple<at::Tensor, at::Tensor> ps_roi_pool(
     int64_t pooled_height,
     int64_t pooled_width);
 
+namespace detail {
+
+VISION_API at::Tensor _ps_roi_pool_backward(
+    const at::Tensor& grad,
+    const at::Tensor& rois,
+    const at::Tensor& channel_mapping,
+    double spatial_scale,
+    int64_t pooled_height,
+    int64_t pooled_width,
+    int64_t batch_size,
+    int64_t channels,
+    int64_t height,
+    int64_t width);
+
+} // namespace detail
+
 } // namespace ops
 } // namespace vision
diff --git a/torchvision/csrc/ops/roi_align.cpp b/torchvision/csrc/ops/roi_align.cpp
index 7116cc422ba..c4b0cc0167d 100644
--- a/torchvision/csrc/ops/roi_align.cpp
+++ b/torchvision/csrc/ops/roi_align.cpp
@@ -1,6 +1,5 @@
 #include "roi_align.h"
 
-#include <torch/autograd.h>
 #include <torch/types.h>
 
 namespace vision {
@@ -30,6 +29,8 @@ at::Tensor roi_align(
       aligned);
 }
 
+namespace detail {
+
 at::Tensor _roi_align_backward(
     const at::Tensor& grad,
     const at::Tensor& rois,
@@ -60,6 +61,8 @@ at::Tensor _roi_align_backward(
       aligned);
 }
 
+} // namespace detail
+
 TORCH_LIBRARY_FRAGMENT(torchvision, m) {
   m.def(
       "roi_align(Tensor input, Tensor rois, float spatial_scale, int pooled_height, int pooled_width, int sampling_ratio, bool aligned) -> Tensor");
@@ -67,157 +70,5 @@ TORCH_LIBRARY_FRAGMENT(torchvision, m) {
       "_roi_align_backward(Tensor grad, Tensor rois, float spatial_scale, int pooled_height, int pooled_width, int batch_size, int channels, int height, int width, int sampling_ratio, bool aligned) -> Tensor");
 }
 
-namespace {
-
-class ROIAlignFunction : public torch::autograd::Function<ROIAlignFunction> {
- public:
-  static torch::autograd::variable_list forward(
-      torch::autograd::AutogradContext* ctx,
-      const torch::autograd::Variable& input,
-      const torch::autograd::Variable& rois,
-      double spatial_scale,
-      int64_t pooled_height,
-      int64_t pooled_width,
-      int64_t sampling_ratio,
-      bool aligned) {
-    ctx->saved_data["spatial_scale"] = spatial_scale;
-    ctx->saved_data["pooled_height"] = pooled_height;
-    ctx->saved_data["pooled_width"] = pooled_width;
-    ctx->saved_data["sampling_ratio"] = sampling_ratio;
-    ctx->saved_data["aligned"] = aligned;
-    ctx->saved_data["input_shape"] = input.sizes();
-    ctx->save_for_backward({rois});
-    at::AutoNonVariableTypeMode g;
-    auto result = roi_align(
-        input,
-        rois,
-        spatial_scale,
-        pooled_height,
-        pooled_width,
-        sampling_ratio,
-        aligned);
-    return {result};
-  }
-
-  static torch::autograd::variable_list backward(
-      torch::autograd::AutogradContext* ctx,
-      const torch::autograd::variable_list& grad_output) {
-    // Use data saved in forward
-    auto saved = ctx->get_saved_variables();
-    auto rois = saved[0];
-    auto input_shape = ctx->saved_data["input_shape"].toIntList();
-    auto grad_in = _roi_align_backward(
-        grad_output[0],
-        rois,
-        ctx->saved_data["spatial_scale"].toDouble(),
-        ctx->saved_data["pooled_height"].toInt(),
-        ctx->saved_data["pooled_width"].toInt(),
-        input_shape[0],
-        input_shape[1],
-        input_shape[2],
-        input_shape[3],
-        ctx->saved_data["sampling_ratio"].toInt(),
-        ctx->saved_data["aligned"].toBool());
-    return {grad_in,
-            torch::autograd::Variable(),
-            torch::autograd::Variable(),
-            torch::autograd::Variable(),
-            torch::autograd::Variable(),
-            torch::autograd::Variable(),
-            torch::autograd::Variable()};
-  }
-};
-
-// TODO: There should be an easier way to do this
-class ROIAlignBackwardFunction
-    : public torch::autograd::Function<ROIAlignBackwardFunction> {
- public:
-  static torch::autograd::variable_list forward(
-      torch::autograd::AutogradContext* ctx,
-      const torch::autograd::Variable& grad,
-      const torch::autograd::Variable& rois,
-      double spatial_scale,
-      int64_t pooled_height,
-      int64_t pooled_width,
-      int64_t batch_size,
-      int64_t channels,
-      int64_t height,
-      int64_t width,
-      int64_t sampling_ratio,
-      bool aligned) {
-    at::AutoNonVariableTypeMode g;
-    auto result = _roi_align_backward(
-        grad,
-        rois,
-        spatial_scale,
-        pooled_height,
-        pooled_width,
-        batch_size,
-        channels,
-        height,
-        width,
-        sampling_ratio,
-        aligned);
-    return {result};
-  }
-
-  static torch::autograd::variable_list backward(
-      torch::autograd::AutogradContext* ctx,
-      const torch::autograd::variable_list& grad_output) {
-    TORCH_CHECK(0, "double backwards on roi_align not supported");
-  }
-};
-
-at::Tensor roi_align_autograd(
-    const at::Tensor& input,
-    const at::Tensor& rois,
-    double spatial_scale,
-    int64_t pooled_height,
-    int64_t pooled_width,
-    int64_t sampling_ratio,
-    bool aligned) {
-  return ROIAlignFunction::apply(
-      input,
-      rois,
-      spatial_scale,
-      pooled_height,
-      pooled_width,
-      sampling_ratio,
-      aligned)[0];
-}
-
-at::Tensor roi_align_backward_autograd(
-    const at::Tensor& grad,
-    const at::Tensor& rois,
-    double spatial_scale,
-    int64_t pooled_height,
-    int64_t pooled_width,
-    int64_t batch_size,
-    int64_t channels,
-    int64_t height,
-    int64_t width,
-    int64_t sampling_ratio,
-    bool aligned) {
-  return ROIAlignBackwardFunction::apply(
-      grad,
-      rois,
-      spatial_scale,
-      pooled_height,
-      pooled_width,
-      batch_size,
-      channels,
-      height,
-      width,
-      sampling_ratio,
-      aligned)[0];
-}
-
-} // namespace
-
-TORCH_LIBRARY_IMPL(torchvision, Autograd, m) {
-  m.impl("roi_align", roi_align_autograd);
-  m.impl("_roi_align_backward", roi_align_backward_autograd);
-}
-
 } // namespace ops
 } // namespace vision
diff --git a/torchvision/csrc/ops/roi_align.h b/torchvision/csrc/ops/roi_align.h
index 628c0fd33dd..33ff2d18fd1 100644
--- a/torchvision/csrc/ops/roi_align.h
+++ b/torchvision/csrc/ops/roi_align.h
@@ -15,5 +15,22 @@ VISION_API at::Tensor roi_align(
     int64_t sampling_ratio,
     bool aligned);
 
+namespace detail {
+
+VISION_API at::Tensor _roi_align_backward(
+    const at::Tensor& grad,
+    const at::Tensor& rois,
+    double spatial_scale,
+    int64_t pooled_height,
+    int64_t pooled_width,
+    int64_t batch_size,
+    int64_t channels,
+    int64_t height,
+    int64_t width,
+    int64_t sampling_ratio,
+    bool aligned);
+
+} // namespace detail
+
 } // namespace ops
 } // namespace vision
diff --git a/torchvision/csrc/ops/roi_pool.cpp b/torchvision/csrc/ops/roi_pool.cpp
index 237c5e9adf8..c2fcb459e3f 100644
--- a/torchvision/csrc/ops/roi_pool.cpp
+++ b/torchvision/csrc/ops/roi_pool.cpp
@@ -1,6 +1,5 @@
 #include "roi_pool.h"
 
-#include <torch/autograd.h>
 #include <torch/types.h>
 
 namespace vision {
@@ -18,6 +17,8 @@ std::tuple<at::Tensor, at::Tensor> roi_pool(
   return op.call(input, rois, spatial_scale, pooled_height, pooled_width);
 }
 
+namespace detail {
+
 at::Tensor _roi_pool_backward(
     const at::Tensor& grad,
     const at::Tensor& rois,
@@ -45,6 +46,8 @@ at::Tensor _roi_pool_backward(
       width);
 }
 
+} // namespace detail
+
 TORCH_LIBRARY_FRAGMENT(torchvision, m) {
   m.def(
       "roi_pool(Tensor input, Tensor rois, float spatial_scale, int pooled_height, int pooled_width) -> (Tensor, Tensor)");
@@ -52,142 +55,5 @@ TORCH_LIBRARY_FRAGMENT(torchvision, m) {
       "_roi_pool_backward(Tensor grad, Tensor rois, Tensor argmax, float spatial_scale, int pooled_height, int pooled_width, int batch_size, int channels, int height, int width) -> Tensor");
 }
 
-namespace {
-
-class ROIPoolFunction : public torch::autograd::Function<ROIPoolFunction> {
- public:
-  static torch::autograd::variable_list forward(
-      torch::autograd::AutogradContext* ctx,
-      const torch::autograd::Variable& input,
-      const torch::autograd::Variable& rois,
-      double spatial_scale,
-      int64_t pooled_height,
-      int64_t pooled_width) {
-    ctx->saved_data["spatial_scale"] = spatial_scale;
-    ctx->saved_data["pooled_height"] = pooled_height;
-    ctx->saved_data["pooled_width"] = pooled_width;
-    ctx->saved_data["input_shape"] = input.sizes();
-    at::AutoNonVariableTypeMode g;
-    auto result =
-        roi_pool(input, rois, spatial_scale, pooled_height, pooled_width);
-
-    auto output = std::get<0>(result);
-    auto argmax = std::get<1>(result);
-    ctx->save_for_backward({rois, argmax});
-    ctx->mark_non_differentiable({argmax});
-
-    return {output, argmax};
-  }
-
-  static torch::autograd::variable_list backward(
-      torch::autograd::AutogradContext* ctx,
-      const torch::autograd::variable_list& grad_output) {
-    // Use data saved in forward
-    auto saved = ctx->get_saved_variables();
-    auto rois = saved[0];
-    auto argmax = saved[1];
-    auto input_shape = ctx->saved_data["input_shape"].toIntList();
-    auto grad_in = _roi_pool_backward(
-        grad_output[0],
-        rois,
-        argmax,
-        ctx->saved_data["spatial_scale"].toDouble(),
-        ctx->saved_data["pooled_height"].toInt(),
-        ctx->saved_data["pooled_width"].toInt(),
-        input_shape[0],
-        input_shape[1],
-        input_shape[2],
-        input_shape[3]);
-
-    return {grad_in,
-            torch::autograd::Variable(),
-            torch::autograd::Variable(),
-            torch::autograd::Variable(),
-            torch::autograd::Variable()};
-  }
-};
-
-// TODO: There should be an easier way to do this
-class ROIPoolBackwardFunction
-    : public torch::autograd::Function<ROIPoolBackwardFunction> {
- public:
-  static torch::autograd::variable_list forward(
-      torch::autograd::AutogradContext* ctx,
-      const torch::autograd::Variable& grad,
-      const torch::autograd::Variable& rois,
-      const torch::autograd::Variable& argmax,
-      double spatial_scale,
-      int64_t pooled_height,
-      int64_t pooled_width,
-      int64_t batch_size,
-      int64_t channels,
-      int64_t height,
-      int64_t width) {
-    at::AutoNonVariableTypeMode g;
-    auto grad_in = _roi_pool_backward(
-        grad,
-        rois,
-        argmax,
-        spatial_scale,
-        pooled_height,
-        pooled_width,
-        batch_size,
-        channels,
-        height,
-        width);
-
-    return {grad_in};
-  }
-
-  static torch::autograd::variable_list backward(
-      torch::autograd::AutogradContext* ctx,
-      const torch::autograd::variable_list& grad_output) {
-    TORCH_CHECK(0, "double backwards on roi_pool not supported");
-  }
-};
-
-std::tuple<at::Tensor, at::Tensor> roi_pool_autograd(
-    const at::Tensor& input,
-    const at::Tensor& rois,
-    double spatial_scale,
-    int64_t pooled_height,
-    int64_t pooled_width) {
-  auto result = ROIPoolFunction::apply(
-      input, rois, spatial_scale, pooled_height, pooled_width);
-
-  return std::make_tuple(result[0], result[1]);
-}
-
-at::Tensor roi_pool_backward_autograd(
-    const at::Tensor& grad,
-    const at::Tensor& rois,
-    const at::Tensor& argmax,
-    double spatial_scale,
-    int64_t pooled_height,
-    int64_t pooled_width,
-    int64_t batch_size,
-    int64_t channels,
-    int64_t height,
-    int64_t width) {
-  return ROIPoolBackwardFunction::apply(
-      grad,
-      rois,
-      argmax,
-      spatial_scale,
-      pooled_height,
-      pooled_width,
-      batch_size,
-      channels,
-      height,
-      width)[0];
-}
-
-} // namespace
-
-TORCH_LIBRARY_IMPL(torchvision, Autograd, m) {
-  m.impl("roi_pool", roi_pool_autograd);
-  m.impl("_roi_pool_backward", roi_pool_backward_autograd);
-}
-
 } // namespace ops
 } // namespace vision
diff --git a/torchvision/csrc/ops/roi_pool.h b/torchvision/csrc/ops/roi_pool.h
index 5ecdb711e1b..6994e68fb4a 100644
--- a/torchvision/csrc/ops/roi_pool.h
+++ b/torchvision/csrc/ops/roi_pool.h
@@ -13,5 +13,21 @@ VISION_API std::tuple<at::Tensor, at::Tensor> roi_pool(
     int64_t pooled_height,
     int64_t pooled_width);
 
+namespace detail {
+
+VISION_API at::Tensor _roi_pool_backward(
+    const at::Tensor& grad,
+    const at::Tensor& rois,
+    const at::Tensor& argmax,
+    double spatial_scale,
+    int64_t pooled_height,
+    int64_t pooled_width,
+    int64_t batch_size,
+    int64_t channels,
+    int64_t height,
+    int64_t width);
+
+} // namespace detail
+
 } // namespace ops
 } // namespace vision