Skip to content

Commit

Permalink
Per file C++ Operator registration (#3135)
Browse files Browse the repository at this point in the history
* Moving deform_conv2d op registration.

* Moving nms op registration.

* Moving new_empty_tensor op registration.

* Moving ps_roi_align op registration.

* Moving ps_roi_pool op registration.

* Moving roi_align op registration.

* Moving roi_pool op registration.

* Restoring headers for forward/backward and fixing styles.

* Restoring the test hack on windows.

* Stricter header inclusion.
  • Loading branch information
datumbox authored Dec 8, 2020
1 parent 6cb4fc2 commit 3c33f36
Show file tree
Hide file tree
Showing 40 changed files with 306 additions and 804 deletions.
4 changes: 2 additions & 2 deletions test/tracing/frcnn/test_frcnn_tracing.cpp
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
#include <ATen/ATen.h>
#include <torch/script.h>
#include <torch/torch.h>
#include <torchvision/roi_align.h>
#include <torchvision/nms.h>
#include <torchvision/roi_align.h>

#ifdef _WIN32
// Windows only
// This is necessary until operators are automatically registered on include
static auto _nms = &vision::ops::nms_cpu;
static auto _nms = &vision::ops::nms;
#endif

int main() {
Expand Down
16 changes: 11 additions & 5 deletions torchvision/csrc/cpu/deform_conv2d_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,8 @@
// modified from
// https://github.com/open-mmlab/mmdetection/blob/master/mmdet/ops/dcn/src/deform_conv_cuda.cpp

#include "deform_conv2d_kernel.h"
#include <ATen/ATen.h>
#include <torch/library.h>

namespace vision {
namespace ops {
Expand Down Expand Up @@ -852,9 +853,7 @@ at::Tensor backward_gradient_parameters(
return grad_weight;
}

} // namespace

at::Tensor deform_conv2d_forward_cpu(
at::Tensor deform_conv2d_forward_kernel(
const at::Tensor& input,
const at::Tensor& weight,
const at::Tensor& offset,
Expand Down Expand Up @@ -1070,7 +1069,7 @@ at::Tensor deform_conv2d_forward_cpu(
}

std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor>
deform_conv2d_backward_cpu(
deform_conv2d_backward_kernel(
const at::Tensor& grad_out,
const at::Tensor& input,
const at::Tensor& weight,
Expand Down Expand Up @@ -1141,5 +1140,12 @@ deform_conv2d_backward_cpu(
grad_input, grad_weight, grad_offset, grad_mask, grad_bias);
}

} // namespace

// Register the CPU kernels for the deform_conv2d forward and backward ops
// under the "torchvision" library namespace. The dispatcher routes calls to
// these implementations when the inputs carry the CPU dispatch key.
// NOTE(review): the schema strings ("deform_conv2d", "_deform_conv2d_backward")
// must match the TORCH_LIBRARY definitions elsewhere in the project — not
// visible in this diff; confirm against the ops' registration TU.
TORCH_LIBRARY_IMPL(torchvision, CPU, m) {
  m.impl("deform_conv2d", deform_conv2d_forward_kernel);
  m.impl("_deform_conv2d_backward", deform_conv2d_backward_kernel);
}

} // namespace ops
} // namespace vision
45 changes: 0 additions & 45 deletions torchvision/csrc/cpu/deform_conv2d_kernel.h

This file was deleted.

15 changes: 10 additions & 5 deletions torchvision/csrc/cpu/nms_kernel.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#include "nms_kernel.h"
#include <ATen/ATen.h>
#include <torch/library.h>

namespace vision {
namespace ops {
Expand Down Expand Up @@ -74,9 +75,7 @@ at::Tensor nms_kernel_impl(
return keep_t.narrow(/*dim=*/0, /*start=*/0, /*length=*/num_to_keep);
}

} // namespace

at::Tensor nms_cpu(
at::Tensor nms_kernel(
const at::Tensor& dets,
const at::Tensor& scores,
double iou_threshold) {
Expand All @@ -101,11 +100,17 @@ at::Tensor nms_cpu(

auto result = at::empty({0}, dets.options());

AT_DISPATCH_FLOATING_TYPES(dets.scalar_type(), "nms_cpu", [&] {
AT_DISPATCH_FLOATING_TYPES(dets.scalar_type(), "nms_kernel", [&] {
result = nms_kernel_impl<scalar_t>(dets, scores, iou_threshold);
});
return result;
}

} // namespace

// Register the CPU implementation of the nms (non-maximum suppression) op
// with the "torchvision" library; the dispatcher selects nms_kernel for
// CPU-tensor inputs. The op name must match the schema declared in the
// project's TORCH_LIBRARY block (not visible in this diff).
TORCH_LIBRARY_IMPL(torchvision, CPU, m) {
  m.impl("nms", nms_kernel);
}

} // namespace ops
} // namespace vision
15 changes: 0 additions & 15 deletions torchvision/csrc/cpu/nms_kernel.h

This file was deleted.

24 changes: 15 additions & 9 deletions torchvision/csrc/cpu/ps_roi_align_kernel.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#include "ps_roi_align_kernel.h"
#include <ATen/ATen.h>
#include <torch/library.h>

namespace vision {
namespace ops {
Expand Down Expand Up @@ -301,9 +302,7 @@ void ps_roi_align_backward_kernel_impl(
}
}

} // namespace

std::tuple<at::Tensor, at::Tensor> ps_roi_align_forward_cpu(
std::tuple<at::Tensor, at::Tensor> ps_roi_align_forward_kernel(
const at::Tensor& input,
const at::Tensor& rois,
double spatial_scale,
Expand All @@ -318,7 +317,7 @@ std::tuple<at::Tensor, at::Tensor> ps_roi_align_forward_cpu(

at::TensorArg input_t{input, "input", 1}, rois_t{rois, "rois", 2};

at::CheckedFrom c = "ps_roi_align_forward_cpu";
at::CheckedFrom c = "ps_roi_align_forward_kernel";
at::checkAllSameType(c, {input_t, rois_t});

int num_rois = rois.size(0);
Expand All @@ -343,7 +342,7 @@ std::tuple<at::Tensor, at::Tensor> ps_roi_align_forward_cpu(

auto input_ = input.contiguous(), rois_ = rois.contiguous();
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
input.scalar_type(), "ps_roi_align_forward_cpu", [&] {
input.scalar_type(), "ps_roi_align_forward_kernel", [&] {
ps_roi_align_forward_kernel_impl<scalar_t>(
output_size,
input_.data_ptr<scalar_t>(),
Expand All @@ -362,7 +361,7 @@ std::tuple<at::Tensor, at::Tensor> ps_roi_align_forward_cpu(
return std::make_tuple(output, channel_mapping);
}

at::Tensor ps_roi_align_backward_cpu(
at::Tensor ps_roi_align_backward_kernel(
const at::Tensor& grad,
const at::Tensor& rois,
const at::Tensor& channel_mapping,
Expand All @@ -384,7 +383,7 @@ at::Tensor ps_roi_align_backward_cpu(
at::TensorArg grad_t{grad, "grad", 1}, rois_t{rois, "rois", 2},
channel_mapping_t{channel_mapping, "channel_mapping", 3};

at::CheckedFrom c = "ps_roi_align_backward_cpu";
at::CheckedFrom c = "ps_roi_align_backward_kernel";
at::checkAllSameType(c, {grad_t, rois_t});

auto num_rois = rois.size(0);
Expand All @@ -400,7 +399,7 @@ at::Tensor ps_roi_align_backward_cpu(

auto grad_ = grad.contiguous(), rois_ = rois.contiguous();
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
grad.scalar_type(), "ps_roi_align_backward_cpu", [&] {
grad.scalar_type(), "ps_roi_align_backward_kernel", [&] {
ps_roi_align_backward_kernel_impl<scalar_t>(
grad.numel(),
grad_.data_ptr<scalar_t>(),
Expand All @@ -420,5 +419,12 @@ at::Tensor ps_roi_align_backward_cpu(
return grad_input;
}

} // namespace

// Register the CPU kernels for ps_roi_align forward and backward under the
// "torchvision" namespace. The leading underscore on the backward op name
// follows the project's convention for internal (autograd-only) ops;
// the names must match the TORCH_LIBRARY schema declarations elsewhere.
TORCH_LIBRARY_IMPL(torchvision, CPU, m) {
  m.impl("ps_roi_align", ps_roi_align_forward_kernel);
  m.impl("_ps_roi_align_backward", ps_roi_align_backward_kernel);
}

} // namespace ops
} // namespace vision
31 changes: 0 additions & 31 deletions torchvision/csrc/cpu/ps_roi_align_kernel.h

This file was deleted.

24 changes: 15 additions & 9 deletions torchvision/csrc/cpu/ps_roi_pool_kernel.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#include "ps_roi_pool_kernel.h"
#include <ATen/ATen.h>
#include <torch/library.h>

namespace vision {
namespace ops {
Expand Down Expand Up @@ -145,9 +146,7 @@ void ps_roi_pool_backward_kernel_impl(
}
}

} // namespace

std::tuple<at::Tensor, at::Tensor> ps_roi_pool_forward_cpu(
std::tuple<at::Tensor, at::Tensor> ps_roi_pool_forward_kernel(
const at::Tensor& input,
const at::Tensor& rois,
double spatial_scale,
Expand All @@ -161,7 +160,7 @@ std::tuple<at::Tensor, at::Tensor> ps_roi_pool_forward_cpu(

at::TensorArg input_t{input, "input", 1}, rois_t{rois, "rois", 2};

at::CheckedFrom c = "ps_roi_pool_forward_cpu";
at::CheckedFrom c = "ps_roi_pool_forward_kernel";
at::checkAllSameType(c, {input_t, rois_t});

int num_rois = rois.size(0);
Expand All @@ -186,7 +185,7 @@ std::tuple<at::Tensor, at::Tensor> ps_roi_pool_forward_cpu(

auto input_ = input.contiguous(), rois_ = rois.contiguous();
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
input.scalar_type(), "ps_roi_pool_forward_cpu", [&] {
input.scalar_type(), "ps_roi_pool_forward_kernel", [&] {
ps_roi_pool_forward_kernel_impl<scalar_t>(
input_.data_ptr<scalar_t>(),
spatial_scale,
Expand All @@ -204,7 +203,7 @@ std::tuple<at::Tensor, at::Tensor> ps_roi_pool_forward_cpu(
return std::make_tuple(output, channel_mapping);
}

at::Tensor ps_roi_pool_backward_cpu(
at::Tensor ps_roi_pool_backward_kernel(
const at::Tensor& grad,
const at::Tensor& rois,
const at::Tensor& channel_mapping,
Expand All @@ -225,7 +224,7 @@ at::Tensor ps_roi_pool_backward_cpu(
at::TensorArg grad_t{grad, "grad", 1}, rois_t{rois, "rois", 2},
channel_mapping_t{channel_mapping, "channel_mapping", 3};

at::CheckedFrom c = "ps_roi_pool_backward_cpu";
at::CheckedFrom c = "ps_roi_pool_backward_kernel";
at::checkAllSameType(c, {grad_t, rois_t});

auto num_rois = rois.size(0);
Expand All @@ -241,7 +240,7 @@ at::Tensor ps_roi_pool_backward_cpu(

auto grad_ = grad.contiguous(), rois_ = rois.contiguous();
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
grad.scalar_type(), "ps_roi_pool_backward_cpu", [&] {
grad.scalar_type(), "ps_roi_pool_backward_kernel", [&] {
ps_roi_pool_backward_kernel_impl<scalar_t>(
grad_.data_ptr<scalar_t>(),
channel_mapping.data_ptr<int>(),
Expand All @@ -259,5 +258,12 @@ at::Tensor ps_roi_pool_backward_cpu(
return grad_input;
}

} // namespace

// Register the CPU kernels for ps_roi_pool forward and backward under the
// "torchvision" namespace, mirroring the registration pattern used by the
// other per-file op registrations in this commit. Op-name strings must match
// the TORCH_LIBRARY schema declarations elsewhere in the project.
TORCH_LIBRARY_IMPL(torchvision, CPU, m) {
  m.impl("ps_roi_pool", ps_roi_pool_forward_kernel);
  m.impl("_ps_roi_pool_backward", ps_roi_pool_backward_kernel);
}

} // namespace ops
} // namespace vision
29 changes: 0 additions & 29 deletions torchvision/csrc/cpu/ps_roi_pool_kernel.h

This file was deleted.

Loading

0 comments on commit 3c33f36

Please sign in to comment.