Skip to content

Commit

Permalink
Registering operators in their files.
Browse files Browse the repository at this point in the history
  • Loading branch information
datumbox committed Nov 30, 2020
1 parent da80ce1 commit f3c5a2e
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 9 deletions.
7 changes: 7 additions & 0 deletions torchvision/csrc/cpu/deform_conv2d_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,8 @@
// modified from
// https://github.com/open-mmlab/mmdetection/blob/master/mmdet/ops/dcn/src/deform_conv_cuda.cpp

#include <torch/script.h>

#include "deform_conv2d_kernel.h"

namespace {
Expand Down Expand Up @@ -1137,3 +1139,8 @@ deform_conv2d_backward_cpu(
return std::make_tuple(
grad_input, grad_weight, grad_offset, grad_mask, grad_bias);
}

TORCH_LIBRARY_IMPL(torchvision, CPU, m) {
m.impl("deform_conv2d", deform_conv2d_forward_cpu);
m.impl("_deform_conv2d_backward", deform_conv2d_backward_cpu);
}
7 changes: 6 additions & 1 deletion torchvision/csrc/cuda/deform_conv2d_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -66,10 +66,10 @@
// modified from
// https://github.com/open-mmlab/mmdetection/blob/master/mmdet/ops/dcn/src/deform_conv_cuda.cpp

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <THC/THCAtomics.cuh>
#include <torch/script.h>

#include "cuda_helpers.h"
#include "deform_conv2d_kernel.h"
Expand Down Expand Up @@ -1188,3 +1188,8 @@ deform_conv2d_backward_cuda(
return std::make_tuple(
grad_input, grad_weight, grad_offset, grad_mask, grad_bias);
}

TORCH_LIBRARY_IMPL(torchvision, CUDA, m) {
m.impl("deform_conv2d", deform_conv2d_forward_cuda);
m.impl("_deform_conv2d_backward", deform_conv2d_backward_cuda);
}
9 changes: 9 additions & 0 deletions torchvision/csrc/deform_conv2d.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,10 @@ at::Tensor deform_conv2d_autocast(
use_mask)
.to(input.scalar_type());
}

TORCH_LIBRARY_IMPL(torchvision, Autocast, m) {
m.impl("deform_conv2d", deform_conv2d_autocast);
}
#endif

std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor>
Expand Down Expand Up @@ -361,3 +365,8 @@ deform_conv2d_backward_autograd(

return std::make_tuple(result[0], result[1], result[2], result[3], result[4]);
}

TORCH_LIBRARY_IMPL(torchvision, Autograd, m) {
m.impl("deform_conv2d", deform_conv2d_autograd);
m.impl("_deform_conv2d_backward", deform_conv2d_backward_autograd);
}
8 changes: 0 additions & 8 deletions torchvision/csrc/vision.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
#include "PSROIPool.h"
#include "ROIAlign.h"
#include "ROIPool.h"
#include "deform_conv2d.h"
#include "empty_tensor_op.h"
#include "nms.h"

Expand Down Expand Up @@ -62,8 +61,6 @@ TORCH_LIBRARY(torchvision, m) {
}

TORCH_LIBRARY_IMPL(torchvision, CPU, m) {
m.impl("deform_conv2d", deform_conv2d_forward_cpu);
m.impl("_deform_conv2d_backward", deform_conv2d_backward_cpu);
m.impl("nms", nms_cpu);
m.impl("ps_roi_align", PSROIAlign_forward_cpu);
m.impl("_ps_roi_align_backward", PSROIAlign_backward_cpu);
Expand All @@ -78,8 +75,6 @@ TORCH_LIBRARY_IMPL(torchvision, CPU, m) {
// TODO: Place this in a hypothetical separate torchvision_cuda library
#if defined(WITH_CUDA) || defined(WITH_HIP)
TORCH_LIBRARY_IMPL(torchvision, CUDA, m) {
m.impl("deform_conv2d", deform_conv2d_forward_cuda);
m.impl("_deform_conv2d_backward", deform_conv2d_backward_cuda);
m.impl("nms", nms_cuda);
m.impl("ps_roi_align", PSROIAlign_forward_cuda);
m.impl("_ps_roi_align_backward", PSROIAlign_backward_cuda);
Expand All @@ -95,7 +90,6 @@ TORCH_LIBRARY_IMPL(torchvision, CUDA, m) {
// Autocast only needs to wrap forward pass ops.
#if defined(WITH_CUDA) || defined(WITH_HIP)
TORCH_LIBRARY_IMPL(torchvision, Autocast, m) {
m.impl("deform_conv2d", deform_conv2d_autocast);
m.impl("nms", nms_autocast);
m.impl("ps_roi_align", PSROIAlign_autocast);
m.impl("ps_roi_pool", PSROIPool_autocast);
Expand All @@ -105,8 +99,6 @@ TORCH_LIBRARY_IMPL(torchvision, Autocast, m) {
#endif

TORCH_LIBRARY_IMPL(torchvision, Autograd, m) {
m.impl("deform_conv2d", deform_conv2d_autograd);
m.impl("_deform_conv2d_backward", deform_conv2d_backward_autograd);
m.impl("ps_roi_align", PSROIAlign_autograd);
m.impl("_ps_roi_align_backward", PSROIAlign_backward_autograd);
m.impl("ps_roi_pool", PSROIPool_autograd);
Expand Down

0 comments on commit f3c5a2e

Please sign in to comment.