Skip to content

Commit

Permalink
Use header files:
Browse files Browse the repository at this point in the history
- Create header files for kernel implementation and remove definitions from vision_*.h files.
- Eliminate unnecessary headers and ensure all cpp include their headers.
- Move internal implementations in detail namespaces.
  • Loading branch information
datumbox committed Nov 26, 2020
1 parent 602acb2 commit 77cbf7c
Show file tree
Hide file tree
Showing 8 changed files with 170 additions and 111 deletions.
26 changes: 10 additions & 16 deletions torchvision/csrc/cpu/deform_conv2d_cpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,15 +66,9 @@
// modified from
// https://github.com/open-mmlab/mmdetection/blob/master/mmdet/ops/dcn/src/deform_conv_cuda.cpp

#include <ATen/ATen.h>
#include <ATen/TensorUtils.h>
#include <TH/TH.h>
#include "deform_conv2d_cpu.h"

#include <cmath>
#include <iostream>
#include <tuple>

namespace {
namespace detail {

const int kMaxParallelImgs = 32;

Expand Down Expand Up @@ -851,7 +845,7 @@ at::Tensor backward_gradient_parameters(
return grad_weight;
}

} // namespace
} // namespace detail

at::Tensor deform_conv2d_forward_cpu(
const at::Tensor& input_param,
Expand Down Expand Up @@ -885,8 +879,8 @@ at::Tensor deform_conv2d_forward_cpu(
int in_h = input.size(2);
int in_w = input.size(3);

int n_parallel_imgs =
get_greatest_divisor_below_bound(batch_sz, kMaxParallelImgs);
int n_parallel_imgs = detail::get_greatest_divisor_below_bound(
batch_sz, detail::kMaxParallelImgs);

// Unpack shapes and args
int out_channels = weight.size(0);
Expand Down Expand Up @@ -1015,7 +1009,7 @@ at::Tensor deform_conv2d_forward_cpu(
{n_in_channels * weight_h * weight_w, n_parallel_imgs * out_h * out_w},
input.options());
for (int b = 0; b < batch_sz / n_parallel_imgs; b++) {
deformable_im2col(
detail::deformable_im2col(
input[b],
offset[b],
mask[b],
Expand Down Expand Up @@ -1086,10 +1080,10 @@ deform_conv2d_backward_cpu(
at::Tensor bias = bias_param.contiguous();

const int batch_sz = input.size(0);
const int n_parallel_imgs =
get_greatest_divisor_below_bound(batch_sz, kMaxParallelImgs);
const int n_parallel_imgs = detail::get_greatest_divisor_below_bound(
batch_sz, detail::kMaxParallelImgs);

auto grad_input_and_offset_and_mask = backward_gradient_inputs(
auto grad_input_and_offset_and_mask = detail::backward_gradient_inputs(
input,
weight,
offset,
Expand All @@ -1110,7 +1104,7 @@ deform_conv2d_backward_cpu(
auto grad_offset = std::get<1>(grad_input_and_offset_and_mask);
auto grad_mask = std::get<2>(grad_input_and_offset_and_mask);

auto grad_weight = backward_gradient_parameters(
auto grad_weight = detail::backward_gradient_parameters(
input,
weight,
offset,
Expand Down
39 changes: 39 additions & 0 deletions torchvision/csrc/cpu/deform_conv2d_cpu.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
#pragma once

#include <ATen/ATen.h>
#include "../macros.h"

VISION_API at::Tensor deform_conv2d_forward_cpu(
const at::Tensor& input_param,
const at::Tensor& weight_param,
const at::Tensor& offset_param,
const at::Tensor& mask_param,
const at::Tensor& bias_param,
int64_t stride_h,
int64_t stride_w,
int64_t pad_h,
int64_t pad_w,
int64_t dil_h,
int64_t dil_w,
int64_t n_weight_grps,
int64_t n_offset_grps,
bool use_mask);

VISION_API std::
tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor>
deform_conv2d_backward_cpu(
const at::Tensor& grad_out_param,
const at::Tensor& input_param,
const at::Tensor& weight_param,
const at::Tensor& offset_param,
const at::Tensor& mask_param,
const at::Tensor& bias_param,
int64_t stride_h,
int64_t stride_w,
int64_t pad_h,
int64_t pad_w,
int64_t dil_h,
int64_t dil_w,
int64_t n_weight_grps,
int64_t n_offset_grps,
bool use_mask);
35 changes: 1 addition & 34 deletions torchvision/csrc/cpu/vision_cpu.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,40 +2,7 @@
#include <torch/extension.h>
#include "../macros.h"

VISION_API at::Tensor deform_conv2d_forward_cpu(
const at::Tensor& input_param,
const at::Tensor& weight_param,
const at::Tensor& offset_param,
const at::Tensor& mask_param,
const at::Tensor& bias_param,
int64_t stride_h,
int64_t stride_w,
int64_t pad_h,
int64_t pad_w,
int64_t dil_h,
int64_t dil_w,
int64_t n_weight_grps,
int64_t n_offset_grps,
bool use_mask);

VISION_API std::
tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor>
deform_conv2d_backward_cpu(
const at::Tensor& grad_out_param,
const at::Tensor& input_param,
const at::Tensor& weight_param,
const at::Tensor& offset_param,
const at::Tensor& mask_param,
const at::Tensor& bias_param,
int64_t stride_h,
int64_t stride_w,
int64_t pad_h,
int64_t pad_w,
int64_t dil_h,
int64_t dil_w,
int64_t n_weight_grps,
int64_t n_offset_grps,
bool use_mask);
// TODO: Delete this file once all the methods are gone

VISION_API at::Tensor nms_cpu(
const at::Tensor& dets,
Expand Down
22 changes: 9 additions & 13 deletions torchvision/csrc/cuda/deform_conv2d_cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -67,18 +67,14 @@
// https://github.com/open-mmlab/mmdetection/blob/master/mmdet/ops/dcn/src/deform_conv_cuda.cpp

#include <ATen/ATen.h>
#include <ATen/TensorUtils.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <THC/THCAtomics.cuh>

#include "deform_conv2d_cuda.h"
#include "cuda_helpers.h"

#include <cmath>
#include <iostream>
#include <tuple>

namespace {
namespace detail {

const int kMaxParallelImgs = 32;

Expand Down Expand Up @@ -898,7 +894,7 @@ at::Tensor backward_gradient_parameters(
return grad_weight;
}

} // namespace
} // namespace detail

at::Tensor deform_conv2d_forward_cuda(
const at::Tensor& input_param,
Expand Down Expand Up @@ -935,7 +931,7 @@ at::Tensor deform_conv2d_forward_cuda(
int in_w = input.size(3);

int n_parallel_imgs =
get_greatest_divisor_below_bound(batch_sz, kMaxParallelImgs);
detail::get_greatest_divisor_below_bound(batch_sz, detail::kMaxParallelImgs);

int out_channels = weight.size(0);
int weight_h = weight.size(2);
Expand Down Expand Up @@ -1063,7 +1059,7 @@ at::Tensor deform_conv2d_forward_cuda(
{in_channels * weight_h * weight_w, n_parallel_imgs * out_h * out_w},
input.options());
for (int b = 0; b < batch_sz / n_parallel_imgs; b++) {
deformable_im2col(
detail::deformable_im2col(
input[b],
offset[b],
mask[b],
Expand Down Expand Up @@ -1134,10 +1130,10 @@ deform_conv2d_backward_cuda(
at::Tensor bias = bias_param.contiguous();

const int batch_sz = input.size(0);
const int n_parallel_imgs =
get_greatest_divisor_below_bound(batch_sz, kMaxParallelImgs);
const int n_parallel_imgs = detail::get_greatest_divisor_below_bound(
batch_sz, detail::kMaxParallelImgs);

auto grad_input_and_offset_and_mask = backward_gradient_inputs(
auto grad_input_and_offset_and_mask = detail::backward_gradient_inputs(
input,
weight,
offset,
Expand All @@ -1158,7 +1154,7 @@ deform_conv2d_backward_cuda(
auto grad_offset = std::get<1>(grad_input_and_offset_and_mask);
auto grad_mask = std::get<2>(grad_input_and_offset_and_mask);

auto grad_weight = backward_gradient_parameters(
auto grad_weight = detail::backward_gradient_parameters(
input,
weight,
offset,
Expand Down
39 changes: 39 additions & 0 deletions torchvision/csrc/cuda/deform_conv2d_cuda.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
#pragma once

#include <ATen/ATen.h>
#include "../macros.h"

VISION_API at::Tensor deform_conv2d_forward_cuda(
const at::Tensor& input_param,
const at::Tensor& weight_param,
const at::Tensor& offset_param,
const at::Tensor& mask_param,
const at::Tensor& bias_param,
int64_t stride_h,
int64_t stride_w,
int64_t pad_h,
int64_t pad_w,
int64_t dil_h,
int64_t dil_w,
int64_t n_weight_grps,
int64_t n_offset_grps,
bool use_mask);

VISION_API std::
tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor>
deform_conv2d_backward_cuda(
const at::Tensor& grad_out_param,
const at::Tensor& input_param,
const at::Tensor& weight_param,
const at::Tensor& offset_param,
const at::Tensor& mask_param,
const at::Tensor& bias_param,
int64_t stride_h,
int64_t stride_w,
int64_t pad_h,
int64_t pad_w,
int64_t dil_h,
int64_t dil_w,
int64_t n_weight_grps,
int64_t n_offset_grps,
bool use_mask);
35 changes: 1 addition & 34 deletions torchvision/csrc/cuda/vision_cuda.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,40 +2,7 @@
#include <torch/extension.h>
#include "../macros.h"

VISION_API at::Tensor deform_conv2d_forward_cuda(
const at::Tensor& input_param,
const at::Tensor& weight_param,
const at::Tensor& offset_param,
const at::Tensor& mask_param,
const at::Tensor& bias_param,
int64_t stride_h,
int64_t stride_w,
int64_t pad_h,
int64_t pad_w,
int64_t dil_h,
int64_t dil_w,
int64_t n_weight_grps,
int64_t n_offset_grps,
bool use_mask);

VISION_API std::
tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor>
deform_conv2d_backward_cuda(
const at::Tensor& grad_out_param,
const at::Tensor& input_param,
const at::Tensor& weight_param,
const at::Tensor& offset_param,
const at::Tensor& mask_param,
const at::Tensor& bias_param,
int64_t stride_h,
int64_t stride_w,
int64_t pad_h,
int64_t pad_w,
int64_t dil_h,
int64_t dil_w,
int64_t n_weight_grps,
int64_t n_offset_grps,
bool use_mask);
// TODO: Delete this file once all the methods are gone

VISION_API at::Tensor nms_cuda(
const at::Tensor& dets,
Expand Down
21 changes: 7 additions & 14 deletions torchvision/csrc/deform_conv2d.cpp
Original file line number Diff line number Diff line change
@@ -1,17 +1,10 @@
#pragma once

#include "cpu/vision_cpu.h"

#ifdef WITH_CUDA
#include "autocast.h"
#include "cuda/vision_cuda.h"
#endif
#ifdef WITH_HIP
#include "deform_conv2d.h"
#include <torch/extension.h>
#include "autocast.h"
#include "hip/vision_cuda.h"
#endif

namespace {
namespace detail {

at::Tensor deform_conv2d(
const at::Tensor& input,
Expand Down Expand Up @@ -261,7 +254,7 @@ class DeformConv2dBackwardFunction
}
};

} // namespace
} // namespace detail

#if defined(WITH_CUDA) || defined(WITH_HIP)
at::Tensor deform_conv2d_autocast(
Expand All @@ -280,7 +273,7 @@ at::Tensor deform_conv2d_autocast(
int64_t offset_groups,
bool use_mask) {
c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::Autocast);
return deform_conv2d(
return detail::deform_conv2d(
at::autocast::cached_cast(at::kFloat, input),
at::autocast::cached_cast(at::kFloat, weight),
at::autocast::cached_cast(at::kFloat, offset),
Expand Down Expand Up @@ -314,7 +307,7 @@ at::Tensor deform_conv2d_autograd(
int64_t groups,
int64_t offset_groups,
bool use_mask) {
return DeformConv2dFunction::apply(
return detail::DeformConv2dFunction::apply(
input,
weight,
offset,
Expand Down Expand Up @@ -348,7 +341,7 @@ deform_conv2d_backward_autograd(
int64_t groups,
int64_t offset_groups,
bool use_mask) {
auto result = DeformConv2dBackwardFunction::apply(
auto result = detail::DeformConv2dBackwardFunction::apply(
grad,
input,
weight,
Expand Down
Loading

0 comments on commit 77cbf7c

Please sign in to comment.