From f3c5a2edfe8f306d6b2bac69c584084c0b86629f Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Mon, 30 Nov 2020 19:06:55 +0000 Subject: [PATCH] Registering operators in their files. --- torchvision/csrc/cpu/deform_conv2d_kernel.cpp | 7 +++++++ torchvision/csrc/cuda/deform_conv2d_kernel.cu | 7 ++++++- torchvision/csrc/deform_conv2d.cpp | 9 +++++++++ torchvision/csrc/vision.cpp | 8 -------- 4 files changed, 22 insertions(+), 9 deletions(-) diff --git a/torchvision/csrc/cpu/deform_conv2d_kernel.cpp b/torchvision/csrc/cpu/deform_conv2d_kernel.cpp index f593e880b3b..a3bd26c82a3 100644 --- a/torchvision/csrc/cpu/deform_conv2d_kernel.cpp +++ b/torchvision/csrc/cpu/deform_conv2d_kernel.cpp @@ -66,6 +66,8 @@ // modified from // https://github.com/open-mmlab/mmdetection/blob/master/mmdet/ops/dcn/src/deform_conv_cuda.cpp +#include + #include "deform_conv2d_kernel.h" namespace { @@ -1137,3 +1139,8 @@ deform_conv2d_backward_cpu( return std::make_tuple( grad_input, grad_weight, grad_offset, grad_mask, grad_bias); } + +TORCH_LIBRARY_IMPL(torchvision, CPU, m) { + m.impl("deform_conv2d", deform_conv2d_forward_cpu); + m.impl("_deform_conv2d_backward", deform_conv2d_backward_cpu); +} diff --git a/torchvision/csrc/cuda/deform_conv2d_kernel.cu b/torchvision/csrc/cuda/deform_conv2d_kernel.cu index 6edaa9c73af..f6e6ceee709 100644 --- a/torchvision/csrc/cuda/deform_conv2d_kernel.cu +++ b/torchvision/csrc/cuda/deform_conv2d_kernel.cu @@ -66,10 +66,10 @@ // modified from // https://github.com/open-mmlab/mmdetection/blob/master/mmdet/ops/dcn/src/deform_conv_cuda.cpp -#include #include #include #include +#include #include "cuda_helpers.h" #include "deform_conv2d_kernel.h" @@ -1188,3 +1188,8 @@ deform_conv2d_backward_cuda( return std::make_tuple( grad_input, grad_weight, grad_offset, grad_mask, grad_bias); } + +TORCH_LIBRARY_IMPL(torchvision, CUDA, m) { + m.impl("deform_conv2d", deform_conv2d_forward_cuda); + m.impl("_deform_conv2d_backward", deform_conv2d_backward_cuda); +} diff --git a/torchvision/csrc/deform_conv2d.cpp b/torchvision/csrc/deform_conv2d.cpp index 74ba630537a..66b91986c2c 100644 --- a/torchvision/csrc/deform_conv2d.cpp +++ b/torchvision/csrc/deform_conv2d.cpp @@ -74,6 +74,10 @@ at::Tensor deform_conv2d_autocast( use_mask) .to(input.scalar_type()); } + +TORCH_LIBRARY_IMPL(torchvision, Autocast, m) { + m.impl("deform_conv2d", deform_conv2d_autocast); +} #endif std::tuple @@ -361,3 +365,8 @@ deform_conv2d_backward_autograd( return std::make_tuple(result[0], result[1], result[2], result[3], result[4]); } + +TORCH_LIBRARY_IMPL(torchvision, Autograd, m) { + m.impl("deform_conv2d", deform_conv2d_autograd); + m.impl("_deform_conv2d_backward", deform_conv2d_backward_autograd); +} diff --git a/torchvision/csrc/vision.cpp b/torchvision/csrc/vision.cpp index 2d4e2af0f53..7a3c16ba509 100644 --- a/torchvision/csrc/vision.cpp +++ b/torchvision/csrc/vision.cpp @@ -12,7 +12,6 @@ #include "PSROIPool.h" #include "ROIAlign.h" #include "ROIPool.h" -#include "deform_conv2d.h" #include "empty_tensor_op.h" #include "nms.h" @@ -62,8 +61,6 @@ TORCH_LIBRARY(torchvision, m) { } TORCH_LIBRARY_IMPL(torchvision, CPU, m) { - m.impl("deform_conv2d", deform_conv2d_forward_cpu); - m.impl("_deform_conv2d_backward", deform_conv2d_backward_cpu); m.impl("nms", nms_cpu); m.impl("ps_roi_align", PSROIAlign_forward_cpu); m.impl("_ps_roi_align_backward", PSROIAlign_backward_cpu); @@ -78,8 +75,6 @@ TORCH_LIBRARY_IMPL(torchvision, CPU, m) { // TODO: Place this in a hypothetical separate torchvision_cuda library #if defined(WITH_CUDA) || defined(WITH_HIP) TORCH_LIBRARY_IMPL(torchvision, CUDA, m) { - m.impl("deform_conv2d", deform_conv2d_forward_cuda); - m.impl("_deform_conv2d_backward", deform_conv2d_backward_cuda); m.impl("nms", nms_cuda); m.impl("ps_roi_align", PSROIAlign_forward_cuda); m.impl("_ps_roi_align_backward", PSROIAlign_backward_cuda); @@ -95,7 +90,6 @@ TORCH_LIBRARY_IMPL(torchvision, CUDA, m) { // Autocast only needs to wrap forward pass ops. #if defined(WITH_CUDA) || defined(WITH_HIP) TORCH_LIBRARY_IMPL(torchvision, Autocast, m) { - m.impl("deform_conv2d", deform_conv2d_autocast); m.impl("nms", nms_autocast); m.impl("ps_roi_align", PSROIAlign_autocast); m.impl("ps_roi_pool", PSROIPool_autocast); @@ -105,8 +99,6 @@ TORCH_LIBRARY_IMPL(torchvision, Autocast, m) { #endif TORCH_LIBRARY_IMPL(torchvision, Autograd, m) { - m.impl("deform_conv2d", deform_conv2d_autograd); - m.impl("_deform_conv2d_backward", deform_conv2d_backward_autograd); m.impl("ps_roi_align", PSROIAlign_autograd); m.impl("_ps_roi_align_backward", PSROIAlign_backward_autograd); m.impl("ps_roi_pool", PSROIPool_autograd);