diff --git a/test/tracing/frcnn/test_frcnn_tracing.cpp b/test/tracing/frcnn/test_frcnn_tracing.cpp index 7b10aee3c89..bd98c7866d8 100644 --- a/test/tracing/frcnn/test_frcnn_tracing.cpp +++ b/test/tracing/frcnn/test_frcnn_tracing.cpp @@ -1,13 +1,13 @@ #include #include #include -#include #include +#include #ifdef _WIN32 // Windows only // This is necessary until operators are automatically registered on include -static auto _nms = &vision::ops::nms_cpu; +static auto _nms = &vision::ops::nms; #endif int main() { diff --git a/torchvision/csrc/cpu/deform_conv2d_kernel.cpp b/torchvision/csrc/cpu/deform_conv2d_kernel.cpp index 4ae2d0a02db..d3f04cafae8 100644 --- a/torchvision/csrc/cpu/deform_conv2d_kernel.cpp +++ b/torchvision/csrc/cpu/deform_conv2d_kernel.cpp @@ -66,7 +66,8 @@ // modified from // https://github.com/open-mmlab/mmdetection/blob/master/mmdet/ops/dcn/src/deform_conv_cuda.cpp -#include "deform_conv2d_kernel.h" +#include +#include namespace vision { namespace ops { @@ -852,9 +853,7 @@ at::Tensor backward_gradient_parameters( return grad_weight; } -} // namespace - -at::Tensor deform_conv2d_forward_cpu( +at::Tensor deform_conv2d_forward_kernel( const at::Tensor& input, const at::Tensor& weight, const at::Tensor& offset, @@ -1070,7 +1069,7 @@ at::Tensor deform_conv2d_forward_cpu( } std::tuple -deform_conv2d_backward_cpu( +deform_conv2d_backward_kernel( const at::Tensor& grad_out, const at::Tensor& input, const at::Tensor& weight, @@ -1141,5 +1140,12 @@ deform_conv2d_backward_cpu( grad_input, grad_weight, grad_offset, grad_mask, grad_bias); } +} // namespace + +TORCH_LIBRARY_IMPL(torchvision, CPU, m) { + m.impl("deform_conv2d", deform_conv2d_forward_kernel); + m.impl("_deform_conv2d_backward", deform_conv2d_backward_kernel); +} + } // namespace ops } // namespace vision diff --git a/torchvision/csrc/cpu/deform_conv2d_kernel.h b/torchvision/csrc/cpu/deform_conv2d_kernel.h deleted file mode 100644 index 2a49bad8304..00000000000 --- a/torchvision/csrc/cpu/deform_conv2d_kernel.h +++ /dev/null @@ -1,45 +0,0 @@ -#pragma once - -#include -#include "../macros.h" - -namespace vision { -namespace ops { - -VISION_API at::Tensor deform_conv2d_forward_cpu( - const at::Tensor& input, - const at::Tensor& weight, - const at::Tensor& offset, - const at::Tensor& mask, - const at::Tensor& bias, - int64_t stride_h, - int64_t stride_w, - int64_t pad_h, - int64_t pad_w, - int64_t dilation_h, - int64_t dilation_w, - int64_t n_weight_grps, - int64_t n_offset_grps, - bool use_mask); - -VISION_API std:: - tuple - deform_conv2d_backward_cpu( - const at::Tensor& grad_out, - const at::Tensor& input, - const at::Tensor& weight, - const at::Tensor& offset, - const at::Tensor& mask, - const at::Tensor& bias, - int64_t stride_h, - int64_t stride_w, - int64_t pad_h, - int64_t pad_w, - int64_t dilation_h, - int64_t dilation_w, - int64_t n_weight_grps, - int64_t n_offset_grps, - bool use_mask); - -} // namespace ops -} // namespace vision diff --git a/torchvision/csrc/cpu/nms_kernel.cpp b/torchvision/csrc/cpu/nms_kernel.cpp index a77a6906870..1bd64b10296 100644 --- a/torchvision/csrc/cpu/nms_kernel.cpp +++ b/torchvision/csrc/cpu/nms_kernel.cpp @@ -1,4 +1,5 @@ -#include "nms_kernel.h" +#include +#include namespace vision { namespace ops { @@ -74,9 +75,7 @@ at::Tensor nms_kernel_impl( return keep_t.narrow(/*dim=*/0, /*start=*/0, /*length=*/num_to_keep); } -} // namespace - -at::Tensor nms_cpu( +at::Tensor nms_kernel( const at::Tensor& dets, const at::Tensor& scores, double iou_threshold) { @@ -101,11 +100,17 @@ at::Tensor nms_cpu( auto result = at::empty({0}, dets.options()); - AT_DISPATCH_FLOATING_TYPES(dets.scalar_type(), "nms_cpu", [&] { + AT_DISPATCH_FLOATING_TYPES(dets.scalar_type(), "nms_kernel", [&] { result = nms_kernel_impl(dets, scores, iou_threshold); }); return result; } +} // namespace + +TORCH_LIBRARY_IMPL(torchvision, CPU, m) { + m.impl("nms", nms_kernel); +} + } // namespace ops } // namespace vision diff --git a/torchvision/csrc/cpu/nms_kernel.h b/torchvision/csrc/cpu/nms_kernel.h deleted file mode 100644 index 1fdcaf3d3f9..00000000000 --- a/torchvision/csrc/cpu/nms_kernel.h +++ /dev/null @@ -1,15 +0,0 @@ -#pragma once - -#include -#include "../macros.h" - -namespace vision { -namespace ops { - -VISION_API at::Tensor nms_cpu( - const at::Tensor& dets, - const at::Tensor& scores, - double iou_threshold); - -} // namespace ops -} // namespace vision diff --git a/torchvision/csrc/cpu/ps_roi_align_kernel.cpp b/torchvision/csrc/cpu/ps_roi_align_kernel.cpp index 5abe4a41477..5c4f978ae87 100644 --- a/torchvision/csrc/cpu/ps_roi_align_kernel.cpp +++ b/torchvision/csrc/cpu/ps_roi_align_kernel.cpp @@ -1,4 +1,5 @@ -#include "ps_roi_align_kernel.h" +#include +#include namespace vision { namespace ops { @@ -301,9 +302,7 @@ void ps_roi_align_backward_kernel_impl( } } -} // namespace - -std::tuple ps_roi_align_forward_cpu( +std::tuple ps_roi_align_forward_kernel( const at::Tensor& input, const at::Tensor& rois, double spatial_scale, @@ -318,7 +317,7 @@ std::tuple ps_roi_align_forward_cpu( at::TensorArg input_t{input, "input", 1}, rois_t{rois, "rois", 2}; - at::CheckedFrom c = "ps_roi_align_forward_cpu"; + at::CheckedFrom c = "ps_roi_align_forward_kernel"; at::checkAllSameType(c, {input_t, rois_t}); int num_rois = rois.size(0); @@ -343,7 +342,7 @@ std::tuple ps_roi_align_forward_cpu( auto input_ = input.contiguous(), rois_ = rois.contiguous(); AT_DISPATCH_FLOATING_TYPES_AND_HALF( - input.scalar_type(), "ps_roi_align_forward_cpu", [&] { + input.scalar_type(), "ps_roi_align_forward_kernel", [&] { ps_roi_align_forward_kernel_impl( output_size, input_.data_ptr(), @@ -362,7 +361,7 @@ std::tuple ps_roi_align_forward_cpu( return std::make_tuple(output, channel_mapping); } -at::Tensor ps_roi_align_backward_cpu( +at::Tensor ps_roi_align_backward_kernel( const at::Tensor& grad, const at::Tensor& rois, const at::Tensor& channel_mapping, @@ -384,7 +383,7 @@ at::Tensor ps_roi_align_backward_cpu( at::TensorArg grad_t{grad, "grad", 1}, rois_t{rois, "rois", 2}, channel_mapping_t{channel_mapping, "channel_mapping", 3}; - at::CheckedFrom c = "ps_roi_align_backward_cpu"; + at::CheckedFrom c = "ps_roi_align_backward_kernel"; at::checkAllSameType(c, {grad_t, rois_t}); auto num_rois = rois.size(0); @@ -400,7 +399,7 @@ at::Tensor ps_roi_align_backward_cpu( auto grad_ = grad.contiguous(), rois_ = rois.contiguous(); AT_DISPATCH_FLOATING_TYPES_AND_HALF( - grad.scalar_type(), "ps_roi_align_backward_cpu", [&] { + grad.scalar_type(), "ps_roi_align_backward_kernel", [&] { ps_roi_align_backward_kernel_impl( grad.numel(), grad_.data_ptr(), @@ -420,5 +419,12 @@ at::Tensor ps_roi_align_backward_cpu( return grad_input; } +} // namespace + +TORCH_LIBRARY_IMPL(torchvision, CPU, m) { + m.impl("ps_roi_align", ps_roi_align_forward_kernel); + m.impl("_ps_roi_align_backward", ps_roi_align_backward_kernel); +} + } // namespace ops } // namespace vision diff --git a/torchvision/csrc/cpu/ps_roi_align_kernel.h b/torchvision/csrc/cpu/ps_roi_align_kernel.h deleted file mode 100644 index a4bea77853b..00000000000 --- a/torchvision/csrc/cpu/ps_roi_align_kernel.h +++ /dev/null @@ -1,31 +0,0 @@ -#pragma once - -#include -#include "../macros.h" - -namespace vision { -namespace ops { - -VISION_API std::tuple ps_roi_align_forward_cpu( - const at::Tensor& input, - const at::Tensor& rois, - double spatial_scale, - int64_t pooled_height, - int64_t pooled_width, - int64_t sampling_ratio); - -VISION_API at::Tensor ps_roi_align_backward_cpu( - const at::Tensor& grad, - const at::Tensor& rois, - const at::Tensor& channel_mapping, - double spatial_scale, - int64_t pooled_height, - int64_t pooled_width, - int64_t sampling_ratio, - int64_t batch_size, - int64_t channels, - int64_t height, - int64_t width); - -} // namespace ops -} // namespace vision diff --git a/torchvision/csrc/cpu/ps_roi_pool_kernel.cpp b/torchvision/csrc/cpu/ps_roi_pool_kernel.cpp index 425b4c68f1a..e20c20869f2 100644 --- a/torchvision/csrc/cpu/ps_roi_pool_kernel.cpp +++ b/torchvision/csrc/cpu/ps_roi_pool_kernel.cpp @@ -1,4 +1,5 @@ -#include "ps_roi_pool_kernel.h" +#include +#include namespace vision { namespace ops { @@ -145,9 +146,7 @@ void ps_roi_pool_backward_kernel_impl( } } -} // namespace - -std::tuple ps_roi_pool_forward_cpu( +std::tuple ps_roi_pool_forward_kernel( const at::Tensor& input, const at::Tensor& rois, double spatial_scale, @@ -161,7 +160,7 @@ std::tuple ps_roi_pool_forward_cpu( at::TensorArg input_t{input, "input", 1}, rois_t{rois, "rois", 2}; - at::CheckedFrom c = "ps_roi_pool_forward_cpu"; + at::CheckedFrom c = "ps_roi_pool_forward_kernel"; at::checkAllSameType(c, {input_t, rois_t}); int num_rois = rois.size(0); @@ -186,7 +185,7 @@ std::tuple ps_roi_pool_forward_cpu( auto input_ = input.contiguous(), rois_ = rois.contiguous(); AT_DISPATCH_FLOATING_TYPES_AND_HALF( - input.scalar_type(), "ps_roi_pool_forward_cpu", [&] { + input.scalar_type(), "ps_roi_pool_forward_kernel", [&] { ps_roi_pool_forward_kernel_impl( input_.data_ptr(), spatial_scale, @@ -204,7 +203,7 @@ std::tuple ps_roi_pool_forward_cpu( return std::make_tuple(output, channel_mapping); } -at::Tensor ps_roi_pool_backward_cpu( +at::Tensor ps_roi_pool_backward_kernel( const at::Tensor& grad, const at::Tensor& rois, const at::Tensor& channel_mapping, @@ -225,7 +224,7 @@ at::Tensor ps_roi_pool_backward_cpu( at::TensorArg grad_t{grad, "grad", 1}, rois_t{rois, "rois", 2}, channel_mapping_t{channel_mapping, "channel_mapping", 3}; - at::CheckedFrom c = "ps_roi_pool_backward_cpu"; + at::CheckedFrom c = "ps_roi_pool_backward_kernel"; at::checkAllSameType(c, {grad_t, rois_t}); auto num_rois = rois.size(0); @@ -241,7 +240,7 @@ at::Tensor ps_roi_pool_backward_cpu( auto grad_ = grad.contiguous(), rois_ = rois.contiguous(); AT_DISPATCH_FLOATING_TYPES_AND_HALF( - grad.scalar_type(), "ps_roi_pool_backward_cpu", [&] { + grad.scalar_type(), "ps_roi_pool_backward_kernel", [&] { ps_roi_pool_backward_kernel_impl( grad_.data_ptr(), channel_mapping.data_ptr(), @@ -259,5 +258,12 @@ at::Tensor ps_roi_pool_backward_cpu( return grad_input; } +} // namespace + +TORCH_LIBRARY_IMPL(torchvision, CPU, m) { + m.impl("ps_roi_pool", ps_roi_pool_forward_kernel); + m.impl("_ps_roi_pool_backward", ps_roi_pool_backward_kernel); +} + } // namespace ops } // namespace vision diff --git a/torchvision/csrc/cpu/ps_roi_pool_kernel.h b/torchvision/csrc/cpu/ps_roi_pool_kernel.h deleted file mode 100644 index 2cefe39e11e..00000000000 --- a/torchvision/csrc/cpu/ps_roi_pool_kernel.h +++ /dev/null @@ -1,29 +0,0 @@ -#pragma once - -#include -#include "../macros.h" - -namespace vision { -namespace ops { - -VISION_API std::tuple ps_roi_pool_forward_cpu( - const at::Tensor& input, - const at::Tensor& rois, - double spatial_scale, - int64_t pooled_height, - int64_t pooled_width); - -VISION_API at::Tensor ps_roi_pool_backward_cpu( - const at::Tensor& grad, - const at::Tensor& rois, - const at::Tensor& channel_mapping, - double spatial_scale, - int64_t pooled_height, - int64_t pooled_width, - int64_t batch_size, - int64_t channels, - int64_t height, - int64_t width); - -} // namespace ops -} // namespace vision diff --git a/torchvision/csrc/cpu/roi_align_kernel.cpp b/torchvision/csrc/cpu/roi_align_kernel.cpp index cbb75f2c474..6b6f36d7d7a 100644 --- a/torchvision/csrc/cpu/roi_align_kernel.cpp +++ b/torchvision/csrc/cpu/roi_align_kernel.cpp @@ -1,4 +1,5 @@ -#include "roi_align_kernel.h" +#include +#include namespace vision { namespace ops { @@ -388,9 +389,7 @@ void roi_align_backward_kernel_impl( } // for } -} // namespace - -at::Tensor roi_align_forward_cpu( +at::Tensor roi_align_forward_kernel( const at::Tensor& input, const at::Tensor& rois, double spatial_scale, @@ -404,7 +403,7 @@ at::Tensor roi_align_forward_cpu( at::TensorArg input_t{input, "input", 1}, rois_t{rois, "rois", 2}; - at::CheckedFrom c = "roi_align_forward_cpu"; + at::CheckedFrom c = "roi_align_forward_kernel"; at::checkAllSameType(c, {input_t, rois_t}); auto num_rois = rois.size(0); @@ -422,7 +421,7 @@ at::Tensor roi_align_forward_cpu( auto input_ = input.contiguous(), rois_ = rois.contiguous(); AT_DISPATCH_FLOATING_TYPES_AND_HALF( - input.scalar_type(), "roi_align_forward_cpu", [&] { + input.scalar_type(), "roi_align_forward_kernel", [&] { roi_align_forward_kernel_impl( output_size, input_.data_ptr(), @@ -440,7 +439,7 @@ at::Tensor roi_align_forward_cpu( return output; } -at::Tensor roi_align_backward_cpu( +at::Tensor roi_align_backward_kernel( const at::Tensor& grad, const at::Tensor& rois, double spatial_scale, @@ -457,7 +456,7 @@ at::Tensor roi_align_backward_cpu( at::TensorArg grad_t{grad, "grad", 1}, rois_t{rois, "rois", 2}; - at::CheckedFrom c = "roi_align_backward_cpu"; + at::CheckedFrom c = "roi_align_backward_kernel"; at::checkAllSameType(c, {grad_t, rois_t}); at::Tensor grad_input = @@ -476,7 +475,7 @@ at::Tensor roi_align_backward_cpu( auto rois_ = rois.contiguous(); AT_DISPATCH_FLOATING_TYPES_AND_HALF( - grad.scalar_type(), "roi_align_backward_cpu", [&] { + grad.scalar_type(), "roi_align_backward_kernel", [&] { roi_align_backward_kernel_impl( grad.numel(), grad.data_ptr(), @@ -498,5 +497,12 @@ at::Tensor roi_align_backward_cpu( return grad_input; } +} // namespace + +TORCH_LIBRARY_IMPL(torchvision, CPU, m) { + m.impl("roi_align", roi_align_forward_kernel); + m.impl("_roi_align_backward", roi_align_backward_kernel); +} + } // namespace ops } // namespace vision diff --git a/torchvision/csrc/cpu/roi_align_kernel.h b/torchvision/csrc/cpu/roi_align_kernel.h deleted file mode 100644 index 2e7813c261c..00000000000 --- a/torchvision/csrc/cpu/roi_align_kernel.h +++ /dev/null @@ -1,32 +0,0 @@ -#pragma once - -#include -#include "../macros.h" - -namespace vision { -namespace ops { - -VISION_API at::Tensor roi_align_forward_cpu( - const at::Tensor& input, - const at::Tensor& rois, - double spatial_scale, - int64_t pooled_height, - int64_t pooled_width, - int64_t sampling_ratio, - bool aligned); - -VISION_API at::Tensor roi_align_backward_cpu( - const at::Tensor& grad, - const at::Tensor& rois, - double spatial_scale, - int64_t pooled_height, - int64_t pooled_width, - int64_t batch_size, - int64_t channels, - int64_t height, - int64_t width, - int64_t sampling_ratio, - bool aligned); - -} // namespace ops -} // namespace vision diff --git a/torchvision/csrc/cpu/roi_pool_kernel.cpp b/torchvision/csrc/cpu/roi_pool_kernel.cpp index 375b722684e..c600778e16e 100644 --- a/torchvision/csrc/cpu/roi_pool_kernel.cpp +++ b/torchvision/csrc/cpu/roi_pool_kernel.cpp @@ -1,6 +1,7 @@ #include -#include "roi_pool_kernel.h" +#include +#include namespace vision { namespace ops { @@ -124,9 +125,7 @@ void roi_pool_backward_kernel_impl( } // num_rois } -} // namespace - -std::tuple roi_pool_forward_cpu( +std::tuple roi_pool_forward_kernel( const at::Tensor& input, const at::Tensor& rois, double spatial_scale, @@ -137,7 +136,7 @@ std::tuple roi_pool_forward_cpu( at::TensorArg input_t{input, "input", 1}, rois_t{rois, "rois", 2}; - at::CheckedFrom c = "roi_pool_forward_cpu"; + at::CheckedFrom c = "roi_pool_forward_kernel"; at::checkAllSameType(c, {input_t, rois_t}); int num_rois = rois.size(0); @@ -157,7 +156,7 @@ std::tuple roi_pool_forward_cpu( auto input_ = input.contiguous(), rois_ = rois.contiguous(); AT_DISPATCH_FLOATING_TYPES_AND_HALF( - input.scalar_type(), "roi_pool_forward_cpu", [&] { + input.scalar_type(), "roi_pool_forward_kernel", [&] { roi_pool_forward_kernel_impl( input_.data_ptr(), spatial_scale, @@ -174,7 +173,7 @@ std::tuple roi_pool_forward_cpu( return std::make_tuple(output, argmax); } -at::Tensor roi_pool_backward_cpu( +at::Tensor roi_pool_backward_kernel( const at::Tensor& grad, const at::Tensor& rois, const at::Tensor& argmax, @@ -194,7 +193,7 @@ at::Tensor roi_pool_backward_cpu( at::TensorArg grad_t{grad, "grad", 1}, rois_t{rois, "rois", 2}; - at::CheckedFrom c = "roi_pool_backward_cpu"; + at::CheckedFrom c = "roi_pool_backward_kernel"; at::checkAllSameType(c, {grad_t, rois_t}); auto num_rois = rois.size(0); @@ -215,7 +214,7 @@ at::Tensor roi_pool_backward_cpu( auto rois_ = rois.contiguous(); AT_DISPATCH_FLOATING_TYPES_AND_HALF( - grad.scalar_type(), "roi_pool_backward_cpu", [&] { + grad.scalar_type(), "roi_pool_backward_kernel", [&] { roi_pool_backward_kernel_impl( grad.data_ptr(), argmax.data_ptr(), @@ -235,5 +234,12 @@ at::Tensor roi_pool_backward_cpu( return grad_input; } +} // namespace + +TORCH_LIBRARY_IMPL(torchvision, CPU, m) { + m.impl("roi_pool", roi_pool_forward_kernel); + m.impl("_roi_pool_backward", roi_pool_backward_kernel); +} + } // namespace ops } // namespace vision diff --git a/torchvision/csrc/cpu/roi_pool_kernel.h b/torchvision/csrc/cpu/roi_pool_kernel.h deleted file mode 100644 index 33d029cf31a..00000000000 --- a/torchvision/csrc/cpu/roi_pool_kernel.h +++ /dev/null @@ -1,29 +0,0 @@ -#pragma once - -#include -#include "../macros.h" - -namespace vision { -namespace ops { - -VISION_API std::tuple roi_pool_forward_cpu( - const at::Tensor& input, - const at::Tensor& rois, - double spatial_scale, - int64_t pooled_height, - int64_t pooled_width); - -VISION_API at::Tensor roi_pool_backward_cpu( - const at::Tensor& grad, - const at::Tensor& rois, - const at::Tensor& argmax, - double spatial_scale, - int64_t pooled_height, - int64_t pooled_width, - int64_t batch_size, - int64_t channels, - int64_t height, - int64_t width); - -} // namespace ops -} // namespace vision diff --git a/torchvision/csrc/cuda/deform_conv2d_kernel.cu b/torchvision/csrc/cuda/deform_conv2d_kernel.cu index e530710863c..99ccf9133fd 100644 --- a/torchvision/csrc/cuda/deform_conv2d_kernel.cu +++ b/torchvision/csrc/cuda/deform_conv2d_kernel.cu @@ -66,12 +66,13 @@ // modified from // https://github.com/open-mmlab/mmdetection/blob/master/mmdet/ops/dcn/src/deform_conv_cuda.cpp +#include #include #include +#include #include #include "cuda_helpers.h" -#include "deform_conv2d_kernel.h" namespace vision { namespace ops { @@ -896,9 +897,7 @@ at::Tensor backward_gradient_parameters( return grad_weight; } -} // namespace - -at::Tensor deform_conv2d_forward_cuda( +at::Tensor deform_conv2d_forward_kernel( const at::Tensor& input, const at::Tensor& weight, const at::Tensor& offset, @@ -1115,7 +1114,7 @@ at::Tensor deform_conv2d_forward_cuda( } std::tuple -deform_conv2d_backward_cuda( +deform_conv2d_backward_kernel( const at::Tensor& grad_out, const at::Tensor& input, const at::Tensor& weight, @@ -1187,5 +1186,12 @@ deform_conv2d_backward_cuda( grad_input, grad_weight, grad_offset, grad_mask, grad_bias); } +} // namespace + +TORCH_LIBRARY_IMPL(torchvision, CUDA, m) { + m.impl("deform_conv2d", deform_conv2d_forward_kernel); + m.impl("_deform_conv2d_backward", deform_conv2d_backward_kernel); +} + } // namespace ops } // namespace vision diff --git a/torchvision/csrc/cuda/deform_conv2d_kernel.h b/torchvision/csrc/cuda/deform_conv2d_kernel.h deleted file mode 100644 index b2e3dc3f17f..00000000000 --- a/torchvision/csrc/cuda/deform_conv2d_kernel.h +++ /dev/null @@ -1,45 +0,0 @@ -#pragma once - -#include -#include "../macros.h" - -namespace vision { -namespace ops { - -VISION_API at::Tensor deform_conv2d_forward_cuda( - const at::Tensor& input, - const at::Tensor& weight, - const at::Tensor& offset, - const at::Tensor& mask, - const at::Tensor& bias, - int64_t stride_h, - int64_t stride_w, - int64_t pad_h, - int64_t pad_w, - int64_t dilation_h, - int64_t dilation_w, - int64_t n_weight_grps, - int64_t n_offset_grps, - bool use_mask); - -VISION_API std:: - tuple - deform_conv2d_backward_cuda( - const at::Tensor& grad_out, - const at::Tensor& input, - const at::Tensor& weight, - const at::Tensor& offset, - const at::Tensor& mask, - const at::Tensor& bias, - int64_t stride_h, - int64_t stride_w, - int64_t pad_h, - int64_t pad_w, - int64_t dilation_h, - int64_t dilation_w, - int64_t n_weight_grps, - int64_t n_offset_grps, - bool use_mask); - -} // namespace ops -} // namespace vision diff --git a/torchvision/csrc/cuda/nms_kernel.cu b/torchvision/csrc/cuda/nms_kernel.cu index b8d4b3ce0ec..024509727a4 100644 --- a/torchvision/csrc/cuda/nms_kernel.cu +++ b/torchvision/csrc/cuda/nms_kernel.cu @@ -1,8 +1,9 @@ +#include #include #include +#include #include "cuda_helpers.h" -#include "nms_kernel.h" namespace vision { namespace ops { @@ -74,9 +75,7 @@ __global__ void nms_kernel_impl( } } -} // namespace - -at::Tensor nms_cuda( +at::Tensor nms_kernel( const at::Tensor& dets, const at::Tensor& scores, double iou_threshold) { @@ -127,7 +126,7 @@ at::Tensor nms_cuda( cudaStream_t stream = at::cuda::getCurrentCUDAStream(); AT_DISPATCH_FLOATING_TYPES_AND_HALF( - dets_sorted.scalar_type(), "nms_cuda", [&] { + dets_sorted.scalar_type(), "nms_kernel", [&] { nms_kernel_impl<<>>( dets_num, iou_threshold, @@ -166,5 +165,11 @@ at::Tensor nms_cuda( .to(order_t.device(), keep.scalar_type())}); } +} // namespace + +TORCH_LIBRARY_IMPL(torchvision, CUDA, m) { + m.impl("nms", nms_kernel); +} + } // namespace ops } // namespace vision diff --git a/torchvision/csrc/cuda/nms_kernel.h b/torchvision/csrc/cuda/nms_kernel.h deleted file mode 100644 index 0d2c0838437..00000000000 --- a/torchvision/csrc/cuda/nms_kernel.h +++ /dev/null @@ -1,15 +0,0 @@ -#pragma once - -#include -#include "../macros.h" - -namespace vision { -namespace ops { - -VISION_API at::Tensor nms_cuda( - const at::Tensor& dets, - const at::Tensor& scores, - double iou_threshold); - -} // namespace ops -} // namespace vision diff --git a/torchvision/csrc/cuda/ps_roi_align_kernel.cu b/torchvision/csrc/cuda/ps_roi_align_kernel.cu index 6b1e729b12d..d001a74e38d 100644 --- a/torchvision/csrc/cuda/ps_roi_align_kernel.cu +++ b/torchvision/csrc/cuda/ps_roi_align_kernel.cu @@ -1,9 +1,10 @@ +#include #include #include +#include #include #include "cuda_helpers.h" -#include "ps_roi_align_kernel.h" namespace vision { namespace ops { @@ -295,9 +296,7 @@ __global__ void ps_roi_align_backward_kernel_impl( } } -} // namespace - -std::tuple ps_roi_align_forward_cuda( +std::tuple ps_roi_align_forward_kernel( const at::Tensor& input, const at::Tensor& rois, double spatial_scale, @@ -312,7 +311,7 @@ std::tuple ps_roi_align_forward_cuda( at::TensorArg input_t{input, "input", 1}, rois_t{rois, "rois", 2}; - at::CheckedFrom c = "ps_roi_align_forward_cuda"; + at::CheckedFrom c = "ps_roi_align_forward_kernel"; at::checkAllSameGPU(c, {input_t, rois_t}); at::checkAllSameType(c, {input_t, rois_t}); @@ -348,7 +347,7 @@ std::tuple ps_roi_align_forward_cuda( auto input_ = input.contiguous(), rois_ = rois.contiguous(); AT_DISPATCH_FLOATING_TYPES_AND_HALF( - input.scalar_type(), "ps_roi_align_forward_cuda", [&] { + input.scalar_type(), "ps_roi_align_forward_kernel", [&] { ps_roi_align_forward_kernel_impl<<>>( output_size, input_.data_ptr(), @@ -369,7 +368,7 @@ std::tuple ps_roi_align_forward_cuda( return std::make_tuple(output, channel_mapping); } -at::Tensor ps_roi_align_backward_cuda( +at::Tensor ps_roi_align_backward_kernel( const at::Tensor& grad, const at::Tensor& rois, const at::Tensor& channel_mapping, @@ -390,7 +389,7 @@ at::Tensor ps_roi_align_backward_cuda( at::TensorArg grad_t{grad, "grad", 1}, rois_t{rois, "rois", 2}, channel_mapping_t{channel_mapping, "channel_mapping", 3}; - at::CheckedFrom c = "ps_roi_align_backward_cuda"; + at::CheckedFrom c = "ps_roi_align_backward_kernel"; at::checkAllSameGPU(c, {grad_t, rois_t, channel_mapping_t}); at::checkAllSameType(c, {grad_t, rois_t}); @@ -417,7 +416,7 @@ at::Tensor ps_roi_align_backward_cuda( auto grad_ = grad.contiguous(), rois_ = rois.contiguous(); AT_DISPATCH_FLOATING_TYPES_AND_HALF( - grad.scalar_type(), "ps_roi_align_backward_cuda", [&] { + grad.scalar_type(), "ps_roi_align_backward_kernel", [&] { ps_roi_align_backward_kernel_impl<<>>( grad.numel(), grad_.data_ptr(), @@ -438,5 +437,12 @@ at::Tensor ps_roi_align_backward_cuda( return grad_input; } +} // namespace + +TORCH_LIBRARY_IMPL(torchvision, CUDA, m) { + m.impl("ps_roi_align", ps_roi_align_forward_kernel); + m.impl("_ps_roi_align_backward", ps_roi_align_backward_kernel); +} + } // namespace ops } // namespace vision diff --git a/torchvision/csrc/cuda/ps_roi_align_kernel.h b/torchvision/csrc/cuda/ps_roi_align_kernel.h deleted file mode 100644 index c40e6fa55b1..00000000000 --- a/torchvision/csrc/cuda/ps_roi_align_kernel.h +++ /dev/null @@ -1,31 +0,0 @@ -#pragma once - -#include -#include "../macros.h" - -namespace vision { -namespace ops { - -VISION_API std::tuple ps_roi_align_forward_cuda( - const at::Tensor& input, - const at::Tensor& rois, - double spatial_scale, - int64_t pooled_height, - int64_t pooled_width, - int64_t sampling_ratio); - -VISION_API at::Tensor ps_roi_align_backward_cuda( - const at::Tensor& grad, - const at::Tensor& rois, - const at::Tensor& channel_mapping, - double spatial_scale, - int64_t pooled_height, - int64_t pooled_width, - int64_t sampling_ratio, - int64_t batch_size, - int64_t channels, - int64_t height, - int64_t width); - -} // namespace ops -} // namespace vision diff --git a/torchvision/csrc/cuda/ps_roi_pool_kernel.cu b/torchvision/csrc/cuda/ps_roi_pool_kernel.cu index 91fd25b4bb5..441c202fa2c 100644 --- a/torchvision/csrc/cuda/ps_roi_pool_kernel.cu +++ b/torchvision/csrc/cuda/ps_roi_pool_kernel.cu @@ -1,9 +1,10 @@ +#include #include #include +#include #include #include "cuda_helpers.h" -#include "ps_roi_pool_kernel.h" namespace vision { namespace ops { @@ -136,9 +137,7 @@ __global__ void ps_roi_pool_backward_kernel_impl( } } -} // namespace - -std::tuple ps_roi_pool_forward_cuda( +std::tuple ps_roi_pool_forward_kernel( const at::Tensor& input, const at::Tensor& rois, double spatial_scale, @@ -152,7 +151,7 @@ std::tuple ps_roi_pool_forward_cuda( at::TensorArg input_t{input, "input", 1}, rois_t{rois, "rois", 2}; - at::CheckedFrom c = "ps_roi_pool_forward_cuda"; + at::CheckedFrom c = "ps_roi_pool_forward_kernel"; at::checkAllSameGPU(c, {input_t, rois_t}); at::checkAllSameType(c, {input_t, rois_t}); @@ -188,7 +187,7 @@ std::tuple ps_roi_pool_forward_cuda( auto input_ = input.contiguous(), rois_ = rois.contiguous(); AT_DISPATCH_FLOATING_TYPES_AND_HALF( - input.scalar_type(), "ps_roi_pool_forward_cuda", [&] { + input.scalar_type(), "ps_roi_pool_forward_kernel", [&] { ps_roi_pool_forward_kernel_impl<<>>( output_size, input_.data_ptr(), @@ -207,7 +206,7 @@ std::tuple ps_roi_pool_forward_cuda( return std::make_tuple(output, channel_mapping); } -at::Tensor ps_roi_pool_backward_cuda( +at::Tensor ps_roi_pool_backward_kernel( const at::Tensor& grad, const at::Tensor& rois, const at::Tensor& channel_mapping, @@ -227,7 +226,7 @@ at::Tensor ps_roi_pool_backward_cuda( at::TensorArg grad_t{grad, "grad", 1}, rois_t{rois, "rois", 2}, channel_mapping_t{channel_mapping, "channel_mapping", 3}; - at::CheckedFrom c = "ps_roi_pool_backward_cuda"; + at::CheckedFrom c = "ps_roi_pool_backward_kernel"; at::checkAllSameGPU(c, {grad_t, rois_t, channel_mapping_t}); at::checkAllSameType(c, {grad_t, rois_t}); @@ -254,7 +253,7 @@ at::Tensor ps_roi_pool_backward_cuda( auto grad_ = grad.contiguous(), rois_ = rois.contiguous(); AT_DISPATCH_FLOATING_TYPES_AND_HALF( - grad.scalar_type(), "ps_roi_pool_backward_cuda", [&] { + grad.scalar_type(), "ps_roi_pool_backward_kernel", [&] { ps_roi_pool_backward_kernel_impl<<>>( grad.numel(), grad_.data_ptr(), @@ -274,5 +273,12 @@ at::Tensor ps_roi_pool_backward_cuda( return grad_input; } +} // namespace + +TORCH_LIBRARY_IMPL(torchvision, CUDA, m) { + m.impl("ps_roi_pool", ps_roi_pool_forward_kernel); + m.impl("_ps_roi_pool_backward", ps_roi_pool_backward_kernel); +} + } // namespace ops } // namespace vision diff --git a/torchvision/csrc/cuda/ps_roi_pool_kernel.h b/torchvision/csrc/cuda/ps_roi_pool_kernel.h deleted file mode 100644 index 21015d4693b..00000000000 --- a/torchvision/csrc/cuda/ps_roi_pool_kernel.h +++ /dev/null @@ -1,29 +0,0 @@ -#pragma once - -#include -#include "../macros.h" - -namespace vision { -namespace ops { - -VISION_API std::tuple ps_roi_pool_forward_cuda( - const at::Tensor& input, - const at::Tensor& rois, - double spatial_scale, - int64_t pooled_height, - int64_t pooled_width); - -VISION_API at::Tensor ps_roi_pool_backward_cuda( - const at::Tensor& grad, - const at::Tensor& rois, - const at::Tensor& channel_mapping, - double spatial_scale, - int64_t pooled_height, - int64_t pooled_width, - int64_t batch_size, - int64_t channels, - int64_t height, - int64_t width); - -} // namespace ops -} // namespace vision diff --git a/torchvision/csrc/cuda/roi_align_kernel.cu b/torchvision/csrc/cuda/roi_align_kernel.cu index 59388faa6ad..0930e4fbf44 100644 --- a/torchvision/csrc/cuda/roi_align_kernel.cu +++ b/torchvision/csrc/cuda/roi_align_kernel.cu @@ -1,9 +1,10 @@ +#include #include #include +#include #include #include "cuda_helpers.h" -#include "roi_align_kernel.h" namespace vision { namespace ops { @@ -314,9 +315,7 @@ __global__ void roi_align_backward_kernel_impl( } // CUDA_1D_KERNEL_LOOP } -} // namespace - -at::Tensor roi_align_forward_cuda( +at::Tensor roi_align_forward_kernel( const at::Tensor& input, const at::Tensor& rois, double spatial_scale, @@ -330,7 +329,7 @@ at::Tensor roi_align_forward_cuda( at::TensorArg input_t{input, "input", 1}, rois_t{rois, "rois", 2}; - at::CheckedFrom c = "roi_align_forward_cuda"; + at::CheckedFrom c = "roi_align_forward_kernel"; at::checkAllSameGPU(c, {input_t, rois_t}); at::checkAllSameType(c, {input_t, rois_t}); @@ -359,7 +358,7 @@ at::Tensor roi_align_forward_cuda( auto input_ = input.contiguous(), rois_ = rois.contiguous(); AT_DISPATCH_FLOATING_TYPES_AND_HALF( - input.scalar_type(), "roi_align_forward_cuda", [&] { + input.scalar_type(), "roi_align_forward_kernel", [&] { roi_align_forward_kernel_impl<<>>( output_size, input_.data_ptr(), @@ -378,7 +377,7 @@ at::Tensor roi_align_forward_cuda( return output; } -at::Tensor roi_align_backward_cuda( +at::Tensor roi_align_backward_kernel( const at::Tensor& grad, const at::Tensor& rois, double spatial_scale, @@ -395,7 +394,7 @@ at::Tensor roi_align_backward_cuda( at::TensorArg grad_t{grad, "grad", 1}, rois_t{rois, "rois", 2}; - at::CheckedFrom c = "roi_align_backward_cuda"; + at::CheckedFrom c = "roi_align_backward_kernel"; at::checkAllSameGPU(c, {grad_t, rois_t}); at::checkAllSameType(c, {grad_t, rois_t}); @@ -424,7 +423,7 @@ at::Tensor roi_align_backward_cuda( auto rois_ = rois.contiguous(); AT_DISPATCH_FLOATING_TYPES_AND_HALF( - grad.scalar_type(), "roi_align_backward_cuda", [&] { + grad.scalar_type(), "roi_align_backward_kernel", [&] { roi_align_backward_kernel_impl<<>>( grad.numel(), grad.data_ptr(), @@ -447,5 +446,12 @@ at::Tensor roi_align_backward_cuda( return grad_input; } +} // namespace + +TORCH_LIBRARY_IMPL(torchvision, CUDA, m) { + m.impl("roi_align", roi_align_forward_kernel); + m.impl("_roi_align_backward", roi_align_backward_kernel); +} + } // namespace ops } // namespace vision diff --git a/torchvision/csrc/cuda/roi_align_kernel.h b/torchvision/csrc/cuda/roi_align_kernel.h deleted file mode 100644 index 71096201627..00000000000 --- a/torchvision/csrc/cuda/roi_align_kernel.h +++ /dev/null @@ -1,32 +0,0 @@ -#pragma once - -#include -#include "../macros.h" - -namespace vision { -namespace ops { - -VISION_API at::Tensor roi_align_forward_cuda( - const at::Tensor& input, - const at::Tensor& rois, - double spatial_scale, - int64_t pooled_height, - int64_t pooled_width, - int64_t sampling_ratio, - bool aligned); - -VISION_API at::Tensor roi_align_backward_cuda( - const at::Tensor& grad, - const at::Tensor& rois, - double spatial_scale, - int64_t pooled_height, - int64_t pooled_width, - int64_t batch_size, - int64_t channels, - int64_t height, - int64_t width, - int64_t sampling_ratio, - bool aligned); - -} // namespace ops -} // namespace vision diff --git a/torchvision/csrc/cuda/roi_pool_kernel.cu b/torchvision/csrc/cuda/roi_pool_kernel.cu index a96e79c87a9..6084e492903 100644 --- a/torchvision/csrc/cuda/roi_pool_kernel.cu +++ b/torchvision/csrc/cuda/roi_pool_kernel.cu @@ -1,10 +1,11 @@ +#include #include #include #include +#include #include #include "cuda_helpers.h" -#include "roi_pool_kernel.h" namespace vision { namespace ops { @@ -120,9 +121,7 @@ __global__ void roi_pool_backward_kernel_impl( } } -} // namespace - -std::tuple roi_pool_forward_cuda( +std::tuple roi_pool_forward_kernel( const at::Tensor& input, const at::Tensor& rois, double spatial_scale, @@ -135,7 +134,7 @@ std::tuple roi_pool_forward_cuda( at::TensorArg input_t{input, "input", 1}, rois_t{rois, "rois", 2}; - at::CheckedFrom c = "roi_pool_forward_cuda"; + at::CheckedFrom c = "roi_pool_forward_kernel"; at::checkAllSameGPU(c, {input_t, rois_t}); at::checkAllSameType(c, {input_t, rois_t}); @@ -167,7 +166,7 @@ std::tuple roi_pool_forward_cuda( auto input_ = input.contiguous(), rois_ = rois.contiguous(); AT_DISPATCH_FLOATING_TYPES_AND_HALF( - input.scalar_type(), "roi_pool_forward_cuda", [&] { + input.scalar_type(), "roi_pool_forward_kernel", [&] { roi_pool_forward_kernel_impl<<>>( output_size, input_.data_ptr(), @@ -185,7 +184,7 @@ std::tuple roi_pool_forward_cuda( return std::make_tuple(output, argmax); } -at::Tensor roi_pool_backward_cuda( +at::Tensor roi_pool_backward_kernel( const at::Tensor& grad, const at::Tensor& rois, const at::Tensor& argmax, @@ -204,7 +203,7 @@ at::Tensor roi_pool_backward_cuda( at::TensorArg grad_t{grad, "grad", 1}, rois_t{rois, "rois", 2}, argmax_t{argmax, "argmax", 3}; - at::CheckedFrom c = "roi_pool_backward_cuda"; + at::CheckedFrom c = "roi_pool_backward_kernel"; at::checkAllSameGPU(c, {grad_t, rois_t, argmax_t}); at::checkAllSameType(c, {grad_t, rois_t}); @@ -235,7 +234,7 @@ at::Tensor roi_pool_backward_cuda( auto argmax_ = argmax.contiguous(), rois_ = rois.contiguous(); AT_DISPATCH_FLOATING_TYPES_AND_HALF( - grad.scalar_type(), "roi_pool_backward_cuda", [&] { + grad.scalar_type(), "roi_pool_backward_kernel", [&] { roi_pool_backward_kernel_impl<<>>( grad.numel(), grad.data_ptr(), @@ -258,5 +257,12 @@ at::Tensor roi_pool_backward_cuda( return grad_input; } +} // namespace + +TORCH_LIBRARY_IMPL(torchvision, CUDA, m) { + m.impl("roi_pool", roi_pool_forward_kernel); + m.impl("_roi_pool_backward", roi_pool_backward_kernel); +} + } // namespace ops } // namespace vision diff --git a/torchvision/csrc/cuda/roi_pool_kernel.h b/torchvision/csrc/cuda/roi_pool_kernel.h deleted file mode 100644 index 71a649968db..00000000000 --- a/torchvision/csrc/cuda/roi_pool_kernel.h +++ /dev/null @@ -1,29 +0,0 @@ -#pragma once - -#include -#include "../macros.h" - -namespace vision { -namespace ops { - -VISION_API std::tuple roi_pool_forward_cuda( - const at::Tensor& input, - const at::Tensor& rois, - double spatial_scale, - int64_t pooled_height, - int64_t pooled_width); - -VISION_API at::Tensor roi_pool_backward_cuda( - const at::Tensor& grad, - const at::Tensor& rois, - const at::Tensor& argmax, - double spatial_scale, - int64_t pooled_height, - int64_t pooled_width, - int64_t batch_size, - int64_t channels, - int64_t height, - int64_t width); - -} // namespace ops -} // namespace vision diff --git a/torchvision/csrc/deform_conv2d.cpp b/torchvision/csrc/deform_conv2d.cpp index e8a416683f2..512c4e02584 100644 --- a/torchvision/csrc/deform_conv2d.cpp +++ b/torchvision/csrc/deform_conv2d.cpp @@ -1,5 +1,7 @@ #include "deform_conv2d.h" -#include + +#include +#include #if defined(WITH_CUDA) || defined(WITH_HIP) #include @@ -77,6 +79,10 @@ at::Tensor deform_conv2d_autocast( use_mask) .to(input.scalar_type()); } + +TORCH_LIBRARY_IMPL(torchvision, Autocast, m) { + m.impl("deform_conv2d", deform_conv2d_autocast); +} #endif std::tuple @@ -118,6 +124,13 @@ _deform_conv2d_backward( use_mask); } +TORCH_LIBRARY_FRAGMENT(torchvision, m) { + m.def( + "deform_conv2d(Tensor input, Tensor weight, Tensor offset, Tensor mask, Tensor bias, int stride_h, int stride_w, int pad_h, int pad_w, int dilation_h, int dilation_w, int groups, int offset_groups, bool use_mask) -> Tensor"); + m.def( + "_deform_conv2d_backward(Tensor grad, Tensor input, Tensor weight, Tensor offset, Tensor mask, Tensor bias, int stride_h, int stride_w, int pad_h, int pad_w, int dilation_h, int dilation_w, int groups, int offset_groups, bool use_mask) -> (Tensor, Tensor, Tensor, Tensor, Tensor)"); +} + namespace { class DeformConv2dFunction @@ -365,5 +378,10 @@ deform_conv2d_backward_autograd( return std::make_tuple(result[0], result[1], result[2], result[3], result[4]); } +TORCH_LIBRARY_IMPL(torchvision, Autograd, m) { + m.impl("deform_conv2d", deform_conv2d_autograd); + m.impl("_deform_conv2d_backward", deform_conv2d_backward_autograd); +} + } // namespace ops } // namespace vision diff --git a/torchvision/csrc/deform_conv2d.h b/torchvision/csrc/deform_conv2d.h index 85675ee6bf2..d802147b660 100644 --- a/torchvision/csrc/deform_conv2d.h +++ b/torchvision/csrc/deform_conv2d.h @@ -1,37 +1,13 @@ #pragma once -#include "cpu/deform_conv2d_kernel.h" - -#ifdef WITH_CUDA -#include "cuda/deform_conv2d_kernel.h" -#endif -#ifdef WITH_HIP -#include "hip/deform_conv2d_kernel.h" -#endif +#include +#include "macros.h" namespace vision { namespace ops { // C++ Forward -at::Tensor deform_conv2d( - const at::Tensor& input, - const at::Tensor& weight, - const at::Tensor& offset, - const at::Tensor& mask, - const at::Tensor& bias, - int64_t stride_h, - int64_t stride_w, - int64_t pad_h, - int64_t pad_w, - int64_t dilation_h, - int64_t dilation_w, - int64_t groups, - int64_t offset_groups, - bool use_mask); - -// Autocast Forward -#if defined(WITH_CUDA) || defined(WITH_HIP) -at::Tensor deform_conv2d_autocast( +VISION_API at::Tensor deform_conv2d( const at::Tensor& input, const at::Tensor& weight, const at::Tensor& offset, @@ -46,9 +22,9 @@ at::Tensor deform_conv2d_autocast( int64_t groups, int64_t offset_groups, bool use_mask); -#endif // C++ Backward +VISION_API std::tuple _deform_conv2d_backward( const at::Tensor& grad, @@ -67,40 +43,5 @@ _deform_conv2d_backward( int64_t offset_groups, bool use_mask); -// Autograd Forward and Backward -at::Tensor deform_conv2d_autograd( - const at::Tensor& input, - const at::Tensor& weight, - const at::Tensor& offset, - const at::Tensor& mask, - const at::Tensor& bias, - int64_t stride_h, - int64_t stride_w, - int64_t pad_h, - int64_t pad_w, - int64_t dilation_h, - int64_t dilation_w, - int64_t groups, - int64_t offset_groups, - bool use_mask); - -std::tuple -deform_conv2d_backward_autograd( - const at::Tensor& grad, - const at::Tensor& input, - const at::Tensor& weight, - const at::Tensor& offset, - const at::Tensor& mask, - const at::Tensor& bias, - int64_t stride_h, - int64_t stride_w, - int64_t pad_h, - int64_t pad_w, - int64_t dilation_h, - int64_t dilation_w, - int64_t groups, - int64_t offset_groups, - bool use_mask); - } // namespace ops } // namespace vision diff --git a/torchvision/csrc/new_empty_tensor_op.cpp b/torchvision/csrc/new_empty_tensor_op.cpp index 30941d52ef7..e2de544be0a 100644 --- a/torchvision/csrc/new_empty_tensor_op.cpp +++ b/torchvision/csrc/new_empty_tensor_op.cpp @@ -1,5 +1,7 @@ #include "new_empty_tensor_op.h" -#include + +#include +#include namespace vision { namespace ops { @@ -35,5 +37,9 @@ at::Tensor new_empty_tensor( return NewEmptyTensorOp::apply(input, shape)[0]; } +TORCH_LIBRARY_FRAGMENT(torchvision, m) { + m.def("_new_empty_tensor_op", &new_empty_tensor); +} + } // namespace ops } // namespace vision diff --git a/torchvision/csrc/new_empty_tensor_op.h b/torchvision/csrc/new_empty_tensor_op.h index f00cb67b779..f1435517cba 100644 --- a/torchvision/csrc/new_empty_tensor_op.h +++ b/torchvision/csrc/new_empty_tensor_op.h @@ -1,11 +1,12 @@ #pragma once #include +#include "macros.h" namespace vision { namespace ops { -at::Tensor new_empty_tensor( +VISION_API at::Tensor new_empty_tensor( const at::Tensor& input, const c10::List& shape); diff --git a/torchvision/csrc/nms.cpp b/torchvision/csrc/nms.cpp index 2f9dbee9a32..74e0dbf82ef 100644 --- a/torchvision/csrc/nms.cpp +++ b/torchvision/csrc/nms.cpp @@ -1,5 +1,7 @@ #include "nms.h" -#include + +#include +#include #if defined(WITH_CUDA) || defined(WITH_HIP) #include @@ -29,7 +31,15 @@ at::Tensor nms_autocast( at::autocast::cached_cast(at::kFloat, scores), iou_threshold); } + +TORCH_LIBRARY_IMPL(torchvision, Autocast, m) { + m.impl("nms", nms_autocast); +} #endif +TORCH_LIBRARY_FRAGMENT(torchvision, m) { + m.def("nms(Tensor dets, Tensor scores, float iou_threshold) -> Tensor"); +} + } // namespace ops } // namespace vision diff --git a/torchvision/csrc/nms.h b/torchvision/csrc/nms.h index ac7cbc53caf..f7dfafe0454 100644 --- a/torchvision/csrc/nms.h +++ b/torchvision/csrc/nms.h @@ -1,30 +1,16 @@ #pragma once -#include "cpu/nms_kernel.h" - -#ifdef WITH_CUDA -#include "cuda/nms_kernel.h" -#endif -#ifdef WITH_HIP -#include "hip/nms_kernel.h" -#endif +#include +#include "macros.h" namespace vision { namespace ops { // C++ Forward -at::Tensor nms( - const at::Tensor& dets, - const at::Tensor& scores, - double iou_threshold); - -// Autocast Forward -#if defined(WITH_CUDA) || defined(WITH_HIP) -at::Tensor nms_autocast( +VISION_API at::Tensor nms( const at::Tensor& dets, const at::Tensor& scores, double iou_threshold); -#endif } // namespace ops } // namespace vision diff --git a/torchvision/csrc/ps_roi_align.cpp b/torchvision/csrc/ps_roi_align.cpp index 5add21aaeec..b684431f567 100644 --- a/torchvision/csrc/ps_roi_align.cpp +++ b/torchvision/csrc/ps_roi_align.cpp @@ -1,5 +1,7 @@ #include "ps_roi_align.h" -#include + +#include +#include #if defined(WITH_CUDA) || defined(WITH_HIP) #include @@ -43,6 +45,10 @@ std::tuple ps_roi_align_autocast( std::get<0>(result).to(input.scalar_type()), std::get<1>(result).to(input.scalar_type())); } + +TORCH_LIBRARY_IMPL(torchvision, Autocast, m) { + m.impl("ps_roi_align", ps_roi_align_autocast); +} #endif at::Tensor _ps_roi_align_backward( @@ -75,6 +81,13 @@ at::Tensor _ps_roi_align_backward( width); } +TORCH_LIBRARY_FRAGMENT(torchvision, m) { + m.def( + "ps_roi_align(Tensor input, Tensor rois, float spatial_scale, int pooled_height, int pooled_width, int sampling_ratio) -> (Tensor, Tensor)"); + m.def( + "_ps_roi_align_backward(Tensor grad, Tensor rois, Tensor channel_mapping, float spatial_scale, int pooled_height, int pooled_width, int sampling_ratio, int batch_size, int channels, int height, int width) -> Tensor"); +} + namespace { class PSROIAlignFunction @@ -222,5 +235,10 @@ at::Tensor ps_roi_align_backward_autograd( width)[0]; } +TORCH_LIBRARY_IMPL(torchvision, Autograd, m) { + m.impl("ps_roi_align", ps_roi_align_autograd); + m.impl("_ps_roi_align_backward", ps_roi_align_backward_autograd); +} + } // namespace ops } // namespace vision diff --git a/torchvision/csrc/ps_roi_align.h b/torchvision/csrc/ps_roi_align.h index c21107df4f4..bd9df43a90e 100644 --- a/torchvision/csrc/ps_roi_align.h +++ b/torchvision/csrc/ps_roi_align.h @@ -1,61 +1,22 @@ #pragma once -#include "cpu/ps_roi_align_kernel.h" - -#ifdef WITH_CUDA -#include "cuda/ps_roi_align_kernel.h" -#endif -#ifdef WITH_HIP -#include "hip/ps_roi_align_kernel.h" -#endif +#include +#include "macros.h" namespace vision { namespace ops { // C++ Forward -std::tuple ps_roi_align( - const at::Tensor& input, - const at::Tensor& rois, - double spatial_scale, - int64_t pooled_height, - int64_t pooled_width, - int64_t sampling_ratio); - -// Autocast Forward -#if defined(WITH_CUDA) || defined(WITH_HIP) -std::tuple ps_roi_align_autocast( +VISION_API std::tuple ps_roi_align( const at::Tensor& input, const at::Tensor& rois, double spatial_scale, int64_t pooled_height, int64_t pooled_width, int64_t sampling_ratio); -#endif // C++ Backward -at::Tensor _ps_roi_align_backward( - const at::Tensor& grad, - const at::Tensor& rois, - const at::Tensor& channel_mapping, - double spatial_scale, - int64_t pooled_height, - int64_t pooled_width, - int64_t sampling_ratio, - int64_t batch_size, - int64_t channels, - int64_t height, - int64_t width); - -// Autograd Forward and Backward -std::tuple ps_roi_align_autograd( - const at::Tensor& input, - const at::Tensor& rois, - double spatial_scale, - int64_t pooled_height, - int64_t pooled_width, - int64_t sampling_ratio); - -at::Tensor ps_roi_align_backward_autograd( +VISION_API at::Tensor _ps_roi_align_backward( const at::Tensor& grad, const at::Tensor& rois, const at::Tensor& channel_mapping, diff --git a/torchvision/csrc/ps_roi_pool.cpp b/torchvision/csrc/ps_roi_pool.cpp index 88a733a6369..99d72771750 100644 --- a/torchvision/csrc/ps_roi_pool.cpp +++ b/torchvision/csrc/ps_roi_pool.cpp @@ -1,5 +1,7 @@ #include "ps_roi_pool.h" -#include + +#include +#include #if defined(WITH_CUDA) || defined(WITH_HIP) #include @@ -39,6 +41,10 @@ std::tuple ps_roi_pool_autocast( std::get<0>(result).to(input.scalar_type()), std::get<1>(result).to(input.scalar_type())); } + +TORCH_LIBRARY_IMPL(torchvision, Autocast, m) { + m.impl("ps_roi_pool", ps_roi_pool_autocast); +} #endif at::Tensor _ps_roi_pool_backward( @@ -69,6 +75,13 @@ at::Tensor _ps_roi_pool_backward( width); } +TORCH_LIBRARY_FRAGMENT(torchvision, m) { + m.def( + "ps_roi_pool(Tensor input, Tensor rois, float spatial_scale, int pooled_height, int pooled_width) -> (Tensor, Tensor)"); + m.def( + "_ps_roi_pool_backward(Tensor grad, Tensor rois, Tensor channel_mapping, float spatial_scale, int pooled_height, int pooled_width, int batch_size, int channels, int height, int width) -> Tensor"); +} + namespace { class PSROIPoolFunction : public torch::autograd::Function { @@ -201,5 +214,10 @@ at::Tensor ps_roi_pool_backward_autograd( width)[0]; } +TORCH_LIBRARY_IMPL(torchvision, Autograd, m) { + m.impl("ps_roi_pool", ps_roi_pool_autograd); + m.impl("_ps_roi_pool_backward", ps_roi_pool_backward_autograd); +} + } // namespace ops } // namespace vision diff --git a/torchvision/csrc/ps_roi_pool.h b/torchvision/csrc/ps_roi_pool.h index 20ae17d3ad1..21f3c640e0e 100644 --- a/torchvision/csrc/ps_roi_pool.h +++ b/torchvision/csrc/ps_roi_pool.h @@ -1,57 +1,21 @@ #pragma once -#include "cpu/ps_roi_pool_kernel.h" - -#ifdef WITH_CUDA -#include "cuda/ps_roi_pool_kernel.h" -#endif -#ifdef WITH_HIP -#include "hip/ps_roi_pool_kernel.h" -#endif +#include +#include "macros.h" namespace vision { namespace ops { // C++ Forward -std::tuple ps_roi_pool( - const at::Tensor& input, - const at::Tensor& rois, - double spatial_scale, - int64_t pooled_height, - int64_t pooled_width); - -// Autocast Forward -#if defined(WITH_CUDA) || defined(WITH_HIP) -std::tuple ps_roi_pool_autocast( +VISION_API std::tuple ps_roi_pool( const at::Tensor& input, const at::Tensor& rois, double spatial_scale, int64_t pooled_height, int64_t pooled_width); -#endif // C++ Backward -at::Tensor _ps_roi_pool_backward( - const at::Tensor& grad, - const at::Tensor& rois, - const at::Tensor& channel_mapping, - double spatial_scale, - int64_t pooled_height, - int64_t pooled_width, - int64_t batch_size, - int64_t channels, - int64_t height, - int64_t width); - -// Autograd Forward and Backward -std::tuple ps_roi_pool_autograd( - const at::Tensor& input, - const at::Tensor& rois, - double spatial_scale, - int64_t pooled_height, - int64_t pooled_width); - -at::Tensor ps_roi_pool_backward_autograd( +VISION_API at::Tensor _ps_roi_pool_backward( const at::Tensor& grad, const at::Tensor& rois, const at::Tensor& channel_mapping, diff --git a/torchvision/csrc/roi_align.cpp b/torchvision/csrc/roi_align.cpp index 63643a6cb46..4a249c97ded 100644 --- a/torchvision/csrc/roi_align.cpp +++ b/torchvision/csrc/roi_align.cpp @@ -1,5 +1,7 @@ #include "roi_align.h" -#include + +#include +#include #if defined(WITH_CUDA) || defined(WITH_HIP) #include @@ -52,6 +54,10 @@ at::Tensor roi_align_autocast( aligned) .to(input.scalar_type()); } + +TORCH_LIBRARY_IMPL(torchvision, Autocast, m) { + m.impl("roi_align", roi_align_autocast); +} #endif at::Tensor _roi_align_backward( @@ -84,6 +90,13 @@ at::Tensor _roi_align_backward( aligned); } +TORCH_LIBRARY_FRAGMENT(torchvision, m) { + m.def( + "roi_align(Tensor input, Tensor rois, float spatial_scale, int pooled_height, int pooled_width, int sampling_ratio, bool aligned) -> Tensor"); + m.def( + "_roi_align_backward(Tensor grad, Tensor rois, float spatial_scale, int pooled_height, int pooled_width, int batch_size, int channels, int height, int width, int sampling_ratio, bool aligned) -> Tensor"); +} + namespace { class ROIAlignFunction : public torch::autograd::Function { @@ -231,5 +244,10 @@ at::Tensor roi_align_backward_autograd( aligned)[0]; } +TORCH_LIBRARY_IMPL(torchvision, Autograd, m) { + m.impl("roi_align", roi_align_autograd); + m.impl("_roi_align_backward", roi_align_backward_autograd); +} + } // namespace ops } // namespace vision diff --git a/torchvision/csrc/roi_align.h b/torchvision/csrc/roi_align.h index 1e92c8d2134..9397d0a173e 100644 --- a/torchvision/csrc/roi_align.h +++ b/torchvision/csrc/roi_align.h @@ -1,19 +1,13 @@ #pragma once -#include "cpu/roi_align_kernel.h" - -#ifdef WITH_CUDA -#include "cuda/roi_align_kernel.h" -#endif -#ifdef WITH_HIP -#include "hip/roi_align_kernel.h" -#endif +#include +#include "macros.h" namespace vision { namespace ops { // C++ Forward -at::Tensor roi_align( +VISION_API at::Tensor roi_align( const at::Tensor& input, const at::Tensor& rois, double spatial_scale, @@ -22,43 +16,8 @@ at::Tensor roi_align( int64_t sampling_ratio, bool aligned); -// Autocast Forward -#if defined(WITH_CUDA) || defined(WITH_HIP) -at::Tensor roi_align_autocast( - const at::Tensor& input, - const at::Tensor& rois, - double spatial_scale, - int64_t pooled_height, - int64_t pooled_width, - int64_t sampling_ratio, - bool aligned); -#endif - // C++ Backward -at::Tensor _roi_align_backward( - const at::Tensor& grad, - const at::Tensor& rois, - double spatial_scale, - int64_t pooled_height, - int64_t pooled_width, - int64_t batch_size, - int64_t channels, - int64_t height, - int64_t width, - int64_t sampling_ratio, - bool aligned); - -// Autograd Forward and Backward -at::Tensor roi_align_autograd( - const at::Tensor& input, - const at::Tensor& rois, - double spatial_scale, - int64_t pooled_height, - int64_t pooled_width, - int64_t sampling_ratio, - bool aligned); - -at::Tensor roi_align_backward_autograd( +VISION_API at::Tensor _roi_align_backward( const at::Tensor& grad, const at::Tensor& rois, double spatial_scale, diff --git a/torchvision/csrc/roi_pool.cpp b/torchvision/csrc/roi_pool.cpp index b2948e6dd23..a23ad4f8614 100644 --- a/torchvision/csrc/roi_pool.cpp +++ b/torchvision/csrc/roi_pool.cpp @@ -1,5 +1,7 @@ #include "roi_pool.h" -#include + +#include +#include #if defined(WITH_CUDA) || defined(WITH_HIP) #include @@ -39,6 +41,10 @@ std::tuple roi_pool_autocast( std::get<0>(result).to(input.scalar_type()), std::get<1>(result).to(input.scalar_type())); } + +TORCH_LIBRARY_IMPL(torchvision, Autocast, m) { + m.impl("roi_pool", roi_pool_autocast); +} #endif at::Tensor _roi_pool_backward( @@ -68,6 +74,13 @@ at::Tensor _roi_pool_backward( width); } +TORCH_LIBRARY_FRAGMENT(torchvision, m) { + m.def( + "roi_pool(Tensor input, Tensor rois, float spatial_scale, int pooled_height, int pooled_width) -> (Tensor, Tensor)"); + m.def( + "_roi_pool_backward(Tensor grad, Tensor rois, Tensor argmax, float spatial_scale, int pooled_height, int pooled_width, int batch_size, int channels, int height, int width) -> Tensor"); +} + namespace { class ROIPoolFunction : public torch::autograd::Function { @@ -200,5 +213,10 @@ at::Tensor roi_pool_backward_autograd( width)[0]; } +TORCH_LIBRARY_IMPL(torchvision, Autograd, m) { + m.impl("roi_pool", roi_pool_autograd); + m.impl("_roi_pool_backward", roi_pool_backward_autograd); +} + } // namespace ops } // namespace vision diff --git a/torchvision/csrc/roi_pool.h b/torchvision/csrc/roi_pool.h index ac005914107..92a2670d98c 100644 --- a/torchvision/csrc/roi_pool.h +++ b/torchvision/csrc/roi_pool.h @@ -1,57 +1,21 @@ #pragma once -#include "cpu/roi_pool_kernel.h" - -#ifdef WITH_CUDA -#include "cuda/roi_pool_kernel.h" -#endif -#ifdef WITH_HIP -#include "hip/roi_pool_kernel.h" -#endif +#include +#include "macros.h" namespace vision { namespace ops { // C++ Forward -std::tuple roi_pool( - const at::Tensor& input, - const at::Tensor& rois, - double spatial_scale, - int64_t pooled_height, - int64_t pooled_width); - -// Autocast Forward -#if defined(WITH_CUDA) || defined(WITH_HIP) -std::tuple roi_pool_autocast( +VISION_API std::tuple roi_pool( const at::Tensor& input, const at::Tensor& rois, double spatial_scale, int64_t pooled_height, int64_t pooled_width); -#endif // C++ Backward -at::Tensor _roi_pool_backward( - const at::Tensor& grad, - const at::Tensor& rois, - const at::Tensor& argmax, - double spatial_scale, - int64_t pooled_height, - int64_t pooled_width, - int64_t batch_size, - int64_t channels, - int64_t height, - int64_t width); - -// Autograd Forward and Backward -std::tuple roi_pool_autograd( - const at::Tensor& input, - const at::Tensor& rois, - double spatial_scale, - int64_t pooled_height, - int64_t pooled_width); - -at::Tensor roi_pool_backward_autograd( +VISION_API at::Tensor _roi_pool_backward( const at::Tensor& grad, const at::Tensor& rois, const at::Tensor& argmax, diff --git a/torchvision/csrc/vision.cpp b/torchvision/csrc/vision.cpp index 766ecd5ff69..baad319e7c0 100644 --- a/torchvision/csrc/vision.cpp +++ b/torchvision/csrc/vision.cpp @@ -1,7 +1,7 @@ #include "vision.h" #include -#include +#include #ifdef WITH_CUDA #include @@ -10,14 +10,6 @@ #include #endif -#include "deform_conv2d.h" -#include "new_empty_tensor_op.h" -#include "nms.h" -#include "ps_roi_align.h" -#include "ps_roi_pool.h" -#include "roi_align.h" -#include "roi_pool.h" - // If we are in a Windows environment, we need to define // initialization functions for the _custom_ops extension #ifdef _WIN32 @@ -35,88 +27,8 @@ int64_t cuda_version() { return -1; #endif } -} // namespace vision - -using namespace vision::ops; - -TORCH_LIBRARY(torchvision, m) { - m.def( - "deform_conv2d(Tensor input, Tensor weight, Tensor offset, Tensor mask, Tensor bias, int stride_h, int stride_w, int pad_h, int pad_w, int dilation_h, int dilation_w, int groups, int offset_groups, bool use_mask) -> Tensor"); - m.def( - "_deform_conv2d_backward(Tensor grad, Tensor input, Tensor weight, Tensor offset, Tensor mask, Tensor bias, int stride_h, int stride_w, int pad_h, int pad_w, int dilation_h, int dilation_w, int groups, int offset_groups, bool use_mask) -> (Tensor, Tensor, Tensor, Tensor, Tensor)"); - m.def("nms(Tensor dets, Tensor scores, float iou_threshold) -> Tensor"); - m.def( - "ps_roi_align(Tensor input, Tensor rois, float spatial_scale, int pooled_height, int pooled_width, int sampling_ratio) -> (Tensor, Tensor)"); - m.def( - "_ps_roi_align_backward(Tensor grad, Tensor rois, Tensor channel_mapping, float spatial_scale, int pooled_height, int pooled_width, int sampling_ratio, int batch_size, int channels, int height, int width) -> Tensor"); - m.def( - "ps_roi_pool(Tensor input, Tensor rois, float spatial_scale, int pooled_height, int pooled_width) -> (Tensor, Tensor)"); - m.def( - "_ps_roi_pool_backward(Tensor grad, Tensor rois, Tensor channel_mapping, float spatial_scale, int pooled_height, int pooled_width, int batch_size, int channels, int height, int width) -> Tensor"); - m.def( - "roi_align(Tensor input, Tensor rois, float spatial_scale, int pooled_height, int pooled_width, int sampling_ratio, bool aligned) -> Tensor"); - m.def( - "_roi_align_backward(Tensor grad, Tensor rois, float spatial_scale, int pooled_height, int pooled_width, int batch_size, int channels, int height, int width, int sampling_ratio, bool aligned) -> Tensor"); - m.def( - "roi_pool(Tensor input, Tensor rois, float spatial_scale, int pooled_height, int pooled_width) -> (Tensor, Tensor)"); - m.def( - "_roi_pool_backward(Tensor grad, Tensor rois, Tensor argmax, float spatial_scale, int pooled_height, int pooled_width, int batch_size, int channels, int height, int width) -> Tensor"); - m.def("_cuda_version", &vision::cuda_version); - m.def("_new_empty_tensor_op", &new_empty_tensor); -} - -TORCH_LIBRARY_IMPL(torchvision, CPU, m) { - m.impl("deform_conv2d", deform_conv2d_forward_cpu); - m.impl("_deform_conv2d_backward", deform_conv2d_backward_cpu); - m.impl("nms", nms_cpu); - m.impl("ps_roi_align", ps_roi_align_forward_cpu); - m.impl("_ps_roi_align_backward", ps_roi_align_backward_cpu); - m.impl("ps_roi_pool", ps_roi_pool_forward_cpu); - m.impl("_ps_roi_pool_backward", ps_roi_pool_backward_cpu); - m.impl("roi_align", roi_align_forward_cpu); - m.impl("_roi_align_backward", roi_align_backward_cpu); - m.impl("roi_pool", roi_pool_forward_cpu); - m.impl("_roi_pool_backward", roi_pool_backward_cpu); -} - -// TODO: Place this in a hypothetical separate torchvision_cuda library -#if defined(WITH_CUDA) || defined(WITH_HIP) -TORCH_LIBRARY_IMPL(torchvision, CUDA, m) { - m.impl("deform_conv2d", deform_conv2d_forward_cuda); - m.impl("_deform_conv2d_backward", deform_conv2d_backward_cuda); - m.impl("nms", nms_cuda); - m.impl("ps_roi_align", ps_roi_align_forward_cuda); - m.impl("_ps_roi_align_backward", ps_roi_align_backward_cuda); - m.impl("ps_roi_pool", ps_roi_pool_forward_cuda); - m.impl("_ps_roi_pool_backward", ps_roi_pool_backward_cuda); - m.impl("roi_align", roi_align_forward_cuda); - m.impl("_roi_align_backward", roi_align_backward_cuda); - m.impl("roi_pool", roi_pool_forward_cuda); - m.impl("_roi_pool_backward", roi_pool_backward_cuda); -} -#endif - -// Autocast only needs to wrap forward pass ops. -#if defined(WITH_CUDA) || defined(WITH_HIP) -TORCH_LIBRARY_IMPL(torchvision, Autocast, m) { - m.impl("deform_conv2d", deform_conv2d_autocast); - m.impl("nms", nms_autocast); - m.impl("ps_roi_align", ps_roi_align_autocast); - m.impl("ps_roi_pool", ps_roi_pool_autocast); - m.impl("roi_align", roi_align_autocast); - m.impl("roi_pool", roi_pool_autocast); -} -#endif -TORCH_LIBRARY_IMPL(torchvision, Autograd, m) { - m.impl("deform_conv2d", deform_conv2d_autograd); - m.impl("_deform_conv2d_backward", deform_conv2d_backward_autograd); - m.impl("ps_roi_align", ps_roi_align_autograd); - m.impl("_ps_roi_align_backward", ps_roi_align_backward_autograd); - m.impl("ps_roi_pool", ps_roi_pool_autograd); - m.impl("_ps_roi_pool_backward", ps_roi_pool_backward_autograd); - m.impl("roi_align", roi_align_autograd); - m.impl("_roi_align_backward", roi_align_backward_autograd); - m.impl("roi_pool", roi_pool_autograd); - m.impl("_roi_pool_backward", roi_pool_backward_autograd); +TORCH_LIBRARY_FRAGMENT(torchvision, m) { + m.def("_cuda_version", &cuda_version); } +} // namespace vision