Move autograd implementations to separate files.
datumbox committed Dec 10, 2020
1 parent 7d831a2 commit db2adb2
Showing 18 changed files with 993 additions and 838 deletions.
3 changes: 2 additions & 1 deletion CMakeLists.txt
@@ -52,7 +52,8 @@ include(GNUInstallDirs)
include(CMakePackageConfigHelpers)

set(TVCPP torchvision/csrc)
-list(APPEND ALLOW_LISTED ${TVCPP} ${TVCPP}/io/image ${TVCPP}/io/image/cpu ${TVCPP}/models ${TVCPP}/ops ${TVCPP}/ops/cpu)
+list(APPEND ALLOW_LISTED ${TVCPP} ${TVCPP}/io/image ${TVCPP}/io/image/cpu ${TVCPP}/models ${TVCPP}/ops
+  ${TVCPP}/ops/autograd ${TVCPP}/ops/cpu)
if(WITH_CUDA)
list(APPEND ALLOW_LISTED ${TVCPP}/ops/cuda ${TVCPP}/ops/autocast)
endif()
3 changes: 2 additions & 1 deletion setup.py
@@ -136,7 +136,8 @@ def get_extensions():

main_file = glob.glob(os.path.join(extensions_dir, '*.cpp')) + glob.glob(os.path.join(extensions_dir, 'ops',
'*.cpp'))
-source_cpu = glob.glob(os.path.join(extensions_dir, 'ops', 'cpu', '*.cpp'))
+source_cpu = glob.glob(os.path.join(extensions_dir, 'ops', 'autograd', '*.cpp')) + glob.glob(
+    os.path.join(extensions_dir, 'ops', 'cpu', '*.cpp'))

is_rocm_pytorch = False
if torch.__version__ >= '1.5':
262 changes: 262 additions & 0 deletions torchvision/csrc/ops/autograd/deform_conv2d_kernel.cpp
@@ -0,0 +1,262 @@
#include "../deform_conv2d.h"

#include <torch/autograd.h>
#include <torch/types.h>

namespace vision {
namespace ops {

namespace {

class DeformConv2dFunction
: public torch::autograd::Function<DeformConv2dFunction> {
public:
static torch::autograd::variable_list forward(
torch::autograd::AutogradContext* ctx,
const torch::autograd::Variable& input,
const torch::autograd::Variable& weight,
const torch::autograd::Variable& offset,
const torch::autograd::Variable& mask,
const torch::autograd::Variable& bias,
int64_t stride_h,
int64_t stride_w,
int64_t pad_h,
int64_t pad_w,
int64_t dilation_h,
int64_t dilation_w,
int64_t groups,
int64_t offset_groups,
bool use_mask) {
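// Temporarily disable autograd dispatch so the call below redispatches to the
// backend (CPU/CUDA) kernel instead of recursing back into this wrapper.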
at::AutoNonVariableTypeMode g;
auto output = deform_conv2d(
input,
weight,
offset,
mask,
bias,
stride_h,
stride_w,
pad_h,
pad_w,
dilation_h,
dilation_w,
groups,
offset_groups,
use_mask);

ctx->save_for_backward({input, weight, offset, mask, bias});
ctx->saved_data["stride_h"] = stride_h;
ctx->saved_data["stride_w"] = stride_w;
ctx->saved_data["pad_h"] = pad_h;
ctx->saved_data["pad_w"] = pad_w;
ctx->saved_data["dilation_h"] = dilation_h;
ctx->saved_data["dilation_w"] = dilation_w;
ctx->saved_data["groups"] = groups;
ctx->saved_data["offset_groups"] = offset_groups;
ctx->saved_data["use_mask"] = use_mask;

return {
output,
};
}

static torch::autograd::variable_list backward(
torch::autograd::AutogradContext* ctx,
const torch::autograd::variable_list& grad_output) {
auto saved = ctx->get_saved_variables();
auto input = saved[0];
auto weight = saved[1];
auto offset = saved[2];
auto mask = saved[3];
auto bias = saved[4];

auto stride_h = ctx->saved_data["stride_h"].toInt();
auto stride_w = ctx->saved_data["stride_w"].toInt();
auto pad_h = ctx->saved_data["pad_h"].toInt();
auto pad_w = ctx->saved_data["pad_w"].toInt();
auto dilation_h = ctx->saved_data["dilation_h"].toInt();
auto dilation_w = ctx->saved_data["dilation_w"].toInt();
auto groups = ctx->saved_data["groups"].toInt();
auto offset_groups = ctx->saved_data["offset_groups"].toInt();
auto use_mask = ctx->saved_data["use_mask"].toBool();

auto grads = detail::_deform_conv2d_backward(
grad_output[0],
input,
weight,
offset,
mask,
bias,
stride_h,
stride_w,
pad_h,
pad_w,
dilation_h,
dilation_w,
groups,
offset_groups,
use_mask);
auto grad_input = std::get<0>(grads);
auto grad_weight = std::get<1>(grads);
auto grad_offset = std::get<2>(grads);
auto grad_mask = std::get<3>(grads);
auto grad_bias = std::get<4>(grads);

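// Autograd expects exactly one gradient per forward() argument; the empty
// Variables below stand in for the nine non-tensor arguments.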
return {
grad_input,
grad_weight,
grad_offset,
grad_mask,
grad_bias,
torch::autograd::Variable(),
torch::autograd::Variable(),
torch::autograd::Variable(),
torch::autograd::Variable(),
torch::autograd::Variable(),
torch::autograd::Variable(),
torch::autograd::Variable(),
torch::autograd::Variable(),
torch::autograd::Variable(),
};
}
};

// TODO: There should be an easier way to do this
class DeformConv2dBackwardFunction
: public torch::autograd::Function<DeformConv2dBackwardFunction> {
public:
static torch::autograd::variable_list forward(
torch::autograd::AutogradContext* ctx,
const torch::autograd::Variable& grad,
const torch::autograd::Variable& input,
const torch::autograd::Variable& weight,
const torch::autograd::Variable& offset,
const torch::autograd::Variable& mask,
const torch::autograd::Variable& bias,
int64_t stride_h,
int64_t stride_w,
int64_t pad_h,
int64_t pad_w,
int64_t dilation_h,
int64_t dilation_w,
int64_t groups,
int64_t offset_groups,
bool use_mask) {
at::AutoNonVariableTypeMode g;
auto result = detail::_deform_conv2d_backward(
grad,
input,
weight,
offset,
mask,
bias,
stride_h,
stride_w,
pad_h,
pad_w,
dilation_h,
dilation_w,
groups,
offset_groups,
use_mask);

auto grad_input = std::get<0>(result);
auto grad_weight = std::get<1>(result);
auto grad_offset = std::get<2>(result);
auto grad_mask = std::get<3>(result);
auto grad_bias = std::get<4>(result);

return {
grad_input,
grad_weight,
grad_offset,
grad_mask,
grad_bias,
};
}

static torch::autograd::variable_list backward(
torch::autograd::AutogradContext* ctx,
const torch::autograd::variable_list& grad_output) {
TORCH_CHECK(0, "double backwards on deform_conv2d not supported");
}
};

at::Tensor deform_conv2d_autograd(
const at::Tensor& input,
const at::Tensor& weight,
const at::Tensor& offset,
const at::Tensor& mask,
const at::Tensor& bias,
int64_t stride_h,
int64_t stride_w,
int64_t pad_h,
int64_t pad_w,
int64_t dilation_h,
int64_t dilation_w,
int64_t groups,
int64_t offset_groups,
bool use_mask) {
return DeformConv2dFunction::apply(
input,
weight,
offset,
mask,
bias,
stride_h,
stride_w,
pad_h,
pad_w,
dilation_h,
dilation_w,
groups,
offset_groups,
use_mask)[0];
}

std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor>
deform_conv2d_backward_autograd(
const at::Tensor& grad,
const at::Tensor& input,
const at::Tensor& weight,
const at::Tensor& offset,
const at::Tensor& mask,
const at::Tensor& bias,
int64_t stride_h,
int64_t stride_w,
int64_t pad_h,
int64_t pad_w,
int64_t dilation_h,
int64_t dilation_w,
int64_t groups,
int64_t offset_groups,
bool use_mask) {
auto result = DeformConv2dBackwardFunction::apply(
grad,
input,
weight,
offset,
mask,
bias,
stride_h,
stride_w,
pad_h,
pad_w,
dilation_h,
dilation_w,
groups,
offset_groups,
use_mask);

return std::make_tuple(result[0], result[1], result[2], result[3], result[4]);
}

} // namespace

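// Register the wrappers above as the Autograd-key implementations of the
// torchvision ops, so calls on tensors that require grad route through them.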
TORCH_LIBRARY_IMPL(torchvision, Autograd, m) {
m.impl("deform_conv2d", deform_conv2d_autograd);
m.impl("_deform_conv2d_backward", deform_conv2d_backward_autograd);
}

} // namespace ops
} // namespace vision
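For reference, a minimal self-contained sketch of the torch::autograd::Function pattern this file follows: forward() runs the op and stashes whatever backward() needs on the context, and backward() returns exactly one gradient per forward() argument. The SquareFunction name and the values below are illustrative only and not part of torchvision; unlike the kernels above, the sketch calls apply() directly instead of registering through the dispatcher.

#include <iostream>
#include <torch/torch.h>

class SquareFunction : public torch::autograd::Function<SquareFunction> {
 public:
  static torch::autograd::variable_list forward(
      torch::autograd::AutogradContext* ctx,
      const torch::autograd::Variable& input) {
    // Save the input; backward() needs it to compute d(x*x)/dx = 2x.
    ctx->save_for_backward({input});
    return {input * input};
  }

  static torch::autograd::variable_list backward(
      torch::autograd::AutogradContext* ctx,
      const torch::autograd::variable_list& grad_output) {
    auto input = ctx->get_saved_variables()[0];
    // Exactly one entry per forward() argument (here, just `input`).
    return {2 * input * grad_output[0]};
  }
};

int main() {
  auto x = torch::randn({3}, torch::requires_grad());
  auto y = SquareFunction::apply(x)[0];
  y.sum().backward();
  std::cout << x.grad() << std::endl; // equals 2 * x
}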
