Skip to content

Commit

Permalink
Use header files:
Browse files Browse the repository at this point in the history
- Create header files for kernel implementation and remove definitions from vision_*.h files.
- Eliminate unnecessary headers and ensure all cpp include their headers.
  • Loading branch information
datumbox committed Nov 26, 2020
1 parent 602acb2 commit 9026df0
Show file tree
Hide file tree
Showing 9 changed files with 152 additions and 90 deletions.
2 changes: 2 additions & 0 deletions torchvision/csrc/autocast.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
#pragma once

// TODO: Delete this file once none of the methods use it

#if defined(WITH_CUDA) || defined(WITH_HIP)
#include <ATen/autocast_mode.h>
#endif
8 changes: 1 addition & 7 deletions torchvision/csrc/cpu/deform_conv2d_cpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,13 +66,7 @@
// modified from
// https://github.com/open-mmlab/mmdetection/blob/master/mmdet/ops/dcn/src/deform_conv_cuda.cpp

#include <ATen/ATen.h>
#include <ATen/TensorUtils.h>
#include <TH/TH.h>

#include <cmath>
#include <iostream>
#include <tuple>
#include "deform_conv2d_cpu.h"

namespace {

Expand Down
39 changes: 39 additions & 0 deletions torchvision/csrc/cpu/deform_conv2d_cpu.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
#pragma once

#include <ATen/ATen.h>
#include "../macros.h"

VISION_API at::Tensor deform_conv2d_forward_cpu(
const at::Tensor& input_param,
const at::Tensor& weight_param,
const at::Tensor& offset_param,
const at::Tensor& mask_param,
const at::Tensor& bias_param,
int64_t stride_h,
int64_t stride_w,
int64_t pad_h,
int64_t pad_w,
int64_t dil_h,
int64_t dil_w,
int64_t n_weight_grps,
int64_t n_offset_grps,
bool use_mask);

VISION_API std::
tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor>
deform_conv2d_backward_cpu(
const at::Tensor& grad_out_param,
const at::Tensor& input_param,
const at::Tensor& weight_param,
const at::Tensor& offset_param,
const at::Tensor& mask_param,
const at::Tensor& bias_param,
int64_t stride_h,
int64_t stride_w,
int64_t pad_h,
int64_t pad_w,
int64_t dil_h,
int64_t dil_w,
int64_t n_weight_grps,
int64_t n_offset_grps,
bool use_mask);
35 changes: 1 addition & 34 deletions torchvision/csrc/cpu/vision_cpu.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,40 +2,7 @@
#include <torch/extension.h>
#include "../macros.h"

VISION_API at::Tensor deform_conv2d_forward_cpu(
const at::Tensor& input_param,
const at::Tensor& weight_param,
const at::Tensor& offset_param,
const at::Tensor& mask_param,
const at::Tensor& bias_param,
int64_t stride_h,
int64_t stride_w,
int64_t pad_h,
int64_t pad_w,
int64_t dil_h,
int64_t dil_w,
int64_t n_weight_grps,
int64_t n_offset_grps,
bool use_mask);

VISION_API std::
tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor>
deform_conv2d_backward_cpu(
const at::Tensor& grad_out_param,
const at::Tensor& input_param,
const at::Tensor& weight_param,
const at::Tensor& offset_param,
const at::Tensor& mask_param,
const at::Tensor& bias_param,
int64_t stride_h,
int64_t stride_w,
int64_t pad_h,
int64_t pad_w,
int64_t dil_h,
int64_t dil_w,
int64_t n_weight_grps,
int64_t n_offset_grps,
bool use_mask);
// TODO: Delete this file once all the methods are gone

VISION_API at::Tensor nms_cpu(
const at::Tensor& dets,
Expand Down
6 changes: 1 addition & 5 deletions torchvision/csrc/cuda/deform_conv2d_cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -67,17 +67,13 @@
// https://github.com/open-mmlab/mmdetection/blob/master/mmdet/ops/dcn/src/deform_conv_cuda.cpp

#include <ATen/ATen.h>
#include <ATen/TensorUtils.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <THC/THCAtomics.cuh>

#include "deform_conv2d_cuda.h"
#include "cuda_helpers.h"

#include <cmath>
#include <iostream>
#include <tuple>

namespace {

const int kMaxParallelImgs = 32;
Expand Down
39 changes: 39 additions & 0 deletions torchvision/csrc/cuda/deform_conv2d_cuda.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
#pragma once

#include <ATen/ATen.h>
#include "../macros.h"

VISION_API at::Tensor deform_conv2d_forward_cuda(
const at::Tensor& input_param,
const at::Tensor& weight_param,
const at::Tensor& offset_param,
const at::Tensor& mask_param,
const at::Tensor& bias_param,
int64_t stride_h,
int64_t stride_w,
int64_t pad_h,
int64_t pad_w,
int64_t dil_h,
int64_t dil_w,
int64_t n_weight_grps,
int64_t n_offset_grps,
bool use_mask);

VISION_API std::
tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor>
deform_conv2d_backward_cuda(
const at::Tensor& grad_out_param,
const at::Tensor& input_param,
const at::Tensor& weight_param,
const at::Tensor& offset_param,
const at::Tensor& mask_param,
const at::Tensor& bias_param,
int64_t stride_h,
int64_t stride_w,
int64_t pad_h,
int64_t pad_w,
int64_t dil_h,
int64_t dil_w,
int64_t n_weight_grps,
int64_t n_offset_grps,
bool use_mask);
35 changes: 1 addition & 34 deletions torchvision/csrc/cuda/vision_cuda.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,40 +2,7 @@
#include <torch/extension.h>
#include "../macros.h"

VISION_API at::Tensor deform_conv2d_forward_cuda(
const at::Tensor& input_param,
const at::Tensor& weight_param,
const at::Tensor& offset_param,
const at::Tensor& mask_param,
const at::Tensor& bias_param,
int64_t stride_h,
int64_t stride_w,
int64_t pad_h,
int64_t pad_w,
int64_t dil_h,
int64_t dil_w,
int64_t n_weight_grps,
int64_t n_offset_grps,
bool use_mask);

VISION_API std::
tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor>
deform_conv2d_backward_cuda(
const at::Tensor& grad_out_param,
const at::Tensor& input_param,
const at::Tensor& weight_param,
const at::Tensor& offset_param,
const at::Tensor& mask_param,
const at::Tensor& bias_param,
int64_t stride_h,
int64_t stride_w,
int64_t pad_h,
int64_t pad_w,
int64_t dil_h,
int64_t dil_w,
int64_t n_weight_grps,
int64_t n_offset_grps,
bool use_mask);
// TODO: Delete this file once all the methods are gone

VISION_API at::Tensor nms_cuda(
const at::Tensor& dets,
Expand Down
14 changes: 4 additions & 10 deletions torchvision/csrc/deform_conv2d.cpp
Original file line number Diff line number Diff line change
@@ -1,14 +1,8 @@
#pragma once
#include "deform_conv2d.h"
#include <torch/extension.h>

#include "cpu/vision_cpu.h"

#ifdef WITH_CUDA
#include "autocast.h"
#include "cuda/vision_cuda.h"
#endif
#ifdef WITH_HIP
#include "autocast.h"
#include "hip/vision_cuda.h"
#if defined(WITH_CUDA) || defined(WITH_HIP)
#include <ATen/autocast_mode.h>
#endif

namespace {
Expand Down
64 changes: 64 additions & 0 deletions torchvision/csrc/deform_conv2d.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
#pragma once

#include "cpu/deform_conv2d_cpu.h"

#ifdef WITH_CUDA
#include "cuda/deform_conv2d_cuda.h"
#endif
#ifdef WITH_HIP
#include "hip/deform_conv2d_cuda.h"
#endif

// Autocast Registration
#if defined(WITH_CUDA) || defined(WITH_HIP)
at::Tensor deform_conv2d_autocast(
const at::Tensor& input,
const at::Tensor& weight,
const at::Tensor& offset,
const at::Tensor& mask,
const at::Tensor& bias,
int64_t stride_h,
int64_t stride_w,
int64_t pad_h,
int64_t pad_w,
int64_t dilation_h,
int64_t dilation_w,
int64_t groups,
int64_t offset_groups,
bool use_mask);
#endif

// Autograd Registration
at::Tensor deform_conv2d_autograd(
const at::Tensor& input,
const at::Tensor& weight,
const at::Tensor& offset,
const at::Tensor& mask,
const at::Tensor& bias,
int64_t stride_h,
int64_t stride_w,
int64_t pad_h,
int64_t pad_w,
int64_t dilation_h,
int64_t dilation_w,
int64_t groups,
int64_t offset_groups,
bool use_mask);

std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor>
deform_conv2d_backward_autograd(
const at::Tensor& grad,
const at::Tensor& input,
const at::Tensor& weight,
const at::Tensor& offset,
const at::Tensor& mask,
const at::Tensor& bias,
int64_t stride_h,
int64_t stride_w,
int64_t pad_h,
int64_t pad_w,
int64_t dilation_h,
int64_t dilation_w,
int64_t groups,
int64_t offset_groups,
bool use_mask);

0 comments on commit 9026df0

Please sign in to comment.