Skip to content

Commit

Permalink
add ms_deformable_attn in parrots
Browse files Browse the repository at this point in the history
  • Loading branch information
luopeichao committed May 21, 2021
1 parent 1a66977 commit bec2e15
Show file tree
Hide file tree
Showing 9 changed files with 611 additions and 54 deletions.
83 changes: 83 additions & 0 deletions mmcv/ops/csrc/parrots/ms_deform_attn.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
/*!
**************************************************************************************************
* Deformable DETR
* Copyright (c) 2020 SenseTime. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
**************************************************************************************************
* Modified from
*https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
**************************************************************************************************
*/

#include "pytorch_cpp_helper.hpp"

#ifdef MMCV_WITH_CUDA
Tensor ms_deform_attn_cuda_forward(const Tensor &value,
const Tensor &spatial_shapes,
const Tensor &level_start_index,
const Tensor &sampling_loc,
const Tensor &attn_weight,
const int im2col_step);

void ms_deform_attn_cuda_backward(
const Tensor &value, const Tensor &spatial_shapes,
const Tensor &level_start_index, const Tensor &sampling_loc,
const Tensor &attn_weight, const Tensor &grad_output,
Tensor &grad_value, Tensor &grad_sampling_loc, Tensor &grad_attn_weight,
const int im2col_step);

#endif

Tensor ms_deform_attn_forward(const Tensor &value, const Tensor &spatial_shapes,
const Tensor &level_start_index,
const Tensor &sampling_loc,
const Tensor &attn_weight,
const int im2col_step) {
if (value.type().is_cuda()) {
#ifdef MMCV_WITH_CUDA
CHECK_CUDA_INPUT(value)
CHECK_CUDA_INPUT(spatial_shapes)
CHECK_CUDA_INPUT(level_start_index)
CHECK_CUDA_INPUT(sampling_loc)
CHECK_CUDA_INPUT(attn_weight)
return ms_deform_attn_cuda_forward(value, spatial_shapes, level_start_index,
sampling_loc, attn_weight, im2col_step);
#else
AT_ERROR("Not compiled with GPU support");
#endif
}
AT_ERROR("Not implemented on the CPU");
}

void ms_deform_attn_backward(const Tensor &value,
const Tensor &spatial_shapes,
const Tensor &level_start_index,
const Tensor &sampling_loc,
const Tensor &attn_weight,
const Tensor &grad_output,
Tensor &grad_value,
Tensor &grad_sampling_loc,
Tensor &grad_attn_weight,
const int im2col_step) {
if (value.type().is_cuda()) {
#ifdef MMCV_WITH_CUDA
CHECK_CUDA_INPUT(value)
CHECK_CUDA_INPUT(spatial_shapes)
CHECK_CUDA_INPUT(level_start_index)
CHECK_CUDA_INPUT(sampling_loc)
CHECK_CUDA_INPUT(attn_weight)
CHECK_CUDA_INPUT(grad_output)
CHECK_CUDA_INPUT(grad_value)
CHECK_CUDA_INPUT(grad_sampling_loc)
CHECK_CUDA_INPUT(grad_attn_weight)
ms_deform_attn_cuda_backward(
value, spatial_shapes, level_start_index, sampling_loc, attn_weight,
grad_output, grad_value, grad_sampling_loc, grad_attn_weight,
im2col_step);
#else
AT_ERROR("Not compiled with GPU support");
#endif
} else {
AT_ERROR("Not implemented on the CPU");
}
}
360 changes: 360 additions & 0 deletions mmcv/ops/csrc/parrots/ms_deform_attn_cuda.cu

Large diffs are not rendered by default.

81 changes: 81 additions & 0 deletions mmcv/ops/csrc/parrots/ms_deform_attn_parrots.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
#include <torch/extension.h>

#include <parrots/compute/aten.hpp>
#include <parrots/extension.hpp>
#include <parrots/foundation/ssattrs.hpp>
using namespace at;
using namespace parrots;

Tensor ms_deform_attn_forward(const Tensor &value, const Tensor &spatial_shapes,
const Tensor &level_start_index,
const Tensor &sampling_loc,
const Tensor &attn_weight,
const int im2col_step);

void ms_deform_attn_backward(const Tensor &value,
const Tensor &spatial_shapes,
const Tensor &level_start_index,
const Tensor &sampling_loc,
const Tensor &attn_weight,
const Tensor &grad_output,
Tensor &grad_value,
Tensor &grad_sampling_loc,
Tensor &grad_attn_weight,
const int im2col_step);

void ms_deform_attn_forward_parrots(CudaContext &ctx, const SSElement &attr,
const OperatorBase::in_list_t &ins,
OperatorBase::out_list_t &outs) {
int im2col_step;
SSAttrs(attr)
.get<int>("im2col_step", im2col_step)
.done();
const auto &value = buildATensor(ctx, ins[0]);
const auto &spatial_shapes = buildATensor(ctx, ins[1]);
const auto &level_start_index = buildATensor(ctx, ins[2]);
const auto &sampling_loc = buildATensor(ctx, ins[3]);
const auto &attn_weight = buildATensor(ctx, ins[4]);
auto out = ms_deform_attn_forward(
value, spatial_shapes, level_start_index, sampling_loc, attn_weight,
im2col_step);
updateDArray(ctx, out, outs[0]);
}




void ms_deform_attn_backward_parrots(CudaContext &ctx, const SSElement &attr,
const OperatorBase::in_list_t &ins,
OperatorBase::out_list_t &outs) {
int im2col_step;
SSAttrs(attr)
.get<int>("im2col_step", im2col_step)
.done();
const auto &value = buildATensor(ctx, ins[0]);
const auto &spatial_shapes = buildATensor(ctx, ins[1]);
const auto &level_start_index = buildATensor(ctx, ins[2]);
const auto &sampling_loc = buildATensor(ctx, ins[3]);
const auto &attn_weight = buildATensor(ctx, ins[4]);
const auto &grad_output = buildATensor(ctx, ins[5]);
auto grad_value = buildATensor(ctx, outs[0]);
auto grad_sampling_loc = buildATensor(ctx, outs[1]);
auto grad_attn_weight = buildATensor(ctx, outs[2]);
ms_deform_attn_backward(
value, spatial_shapes, level_start_index, sampling_loc, attn_weight,
grad_output, grad_value, grad_sampling_loc, grad_attn_weight,
im2col_step);
}

PARROTS_EXTENSION_REGISTER(ms_deform_attn_forward)
.attr("im2col_step")
.input(5)
.output(1)
.apply(ms_deform_attn_forward_parrots)
.done();

PARROTS_EXTENSION_REGISTER(ms_deform_attn_backward)
.attr("im2col_step")
.input(6)
.output(3)
.apply(ms_deform_attn_backward_parrots)
.done();
33 changes: 21 additions & 12 deletions mmcv/ops/csrc/pytorch/ms_deform_attn.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,11 @@ Tensor ms_deform_attn_cuda_forward(const Tensor &value,
const Tensor &attn_weight,
const int im2col_step);

std::vector<Tensor> ms_deform_attn_cuda_backward(
void ms_deform_attn_cuda_backward(
const Tensor &value, const Tensor &spatial_shapes,
const Tensor &level_start_index, const Tensor &sampling_loc,
const Tensor &attn_weight, const Tensor &grad_output,
Tensor &grad_value, Tensor &grad_sampling_loc, Tensor &grad_attn_weight,
const int im2col_step);

#endif
Expand All @@ -48,13 +49,16 @@ Tensor ms_deform_attn_forward(const Tensor &value, const Tensor &spatial_shapes,
AT_ERROR("Not implemented on the CPU");
}

std::vector<Tensor> ms_deform_attn_backward(const Tensor &value,
const Tensor &spatial_shapes,
const Tensor &level_start_index,
const Tensor &sampling_loc,
const Tensor &attn_weight,
const Tensor &grad_output,
const int im2col_step) {
void ms_deform_attn_backward(const Tensor &value,
const Tensor &spatial_shapes,
const Tensor &level_start_index,
const Tensor &sampling_loc,
const Tensor &attn_weight,
const Tensor &grad_output,
Tensor &grad_value,
Tensor &grad_sampling_loc,
Tensor &grad_attn_weight,
const int im2col_step) {
if (value.type().is_cuda()) {
#ifdef MMCV_WITH_CUDA
CHECK_CUDA_INPUT(value)
Expand All @@ -63,12 +67,17 @@ std::vector<Tensor> ms_deform_attn_backward(const Tensor &value,
CHECK_CUDA_INPUT(sampling_loc)
CHECK_CUDA_INPUT(attn_weight)
CHECK_CUDA_INPUT(grad_output)
return ms_deform_attn_cuda_backward(value, spatial_shapes,
level_start_index, sampling_loc,
attn_weight, grad_output, im2col_step);
CHECK_CUDA_INPUT(grad_value)
CHECK_CUDA_INPUT(grad_sampling_loc)
CHECK_CUDA_INPUT(grad_attn_weight)
ms_deform_attn_cuda_backward(
value, spatial_shapes, level_start_index, sampling_loc, attn_weight,
grad_output, grad_value, grad_sampling_loc, grad_attn_weight,
im2col_step);
#else
AT_ERROR("Not compiled with GPU support");
#endif
}
} else {
AT_ERROR("Not implemented on the CPU");
}
}
11 changes: 3 additions & 8 deletions mmcv/ops/csrc/pytorch/ms_deform_attn_cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -286,11 +286,12 @@ at::Tensor ms_deform_attn_cuda_forward(const at::Tensor &value,
return output;
}

std::vector<at::Tensor> ms_deform_attn_cuda_backward(
void ms_deform_attn_cuda_backward(
const at::Tensor &value, const at::Tensor &spatial_shapes,
const at::Tensor &level_start_index, const at::Tensor &sampling_loc,
const at::Tensor &attn_weight, const at::Tensor &grad_output,
const int im2col_step) {
at::Tensor &grad_value, at::Tensor &grad_sampling_loc,
at::Tensor &grad_attn_weight, const int im2col_step) {
AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous");
AT_ASSERTM(spatial_shapes.is_contiguous(),
"spatial_shapes tensor has to be contiguous");
Expand Down Expand Up @@ -328,10 +329,6 @@ std::vector<at::Tensor> ms_deform_attn_cuda_backward(
AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)",
batch, im2col_step_);

auto grad_value = at::zeros_like(value);
auto grad_sampling_loc = at::zeros_like(sampling_loc);
auto grad_attn_weight = at::zeros_like(attn_weight);

const int batch_n = im2col_step_;
auto per_value_size = spatial_size * num_heads * channels;
auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2;
Expand Down Expand Up @@ -360,6 +357,4 @@ std::vector<at::Tensor> ms_deform_attn_cuda_backward(
n * im2col_step_ * per_attn_weight_size);
}));
}

return {grad_value, grad_sampling_loc, grad_attn_weight};
}
25 changes: 12 additions & 13 deletions mmcv/ops/csrc/pytorch/pybind.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -92,18 +92,16 @@ void modulated_deform_conv_backward(
int pad_w, int dilation_h, int dilation_w, int group, int deformable_group,
const bool with_bias);

Tensor ms_deform_attn_forward(const Tensor &value, const Tensor &spatial_shapes,
const Tensor &level_start_index,
const Tensor &sampling_loc,
const Tensor &attn_weight, const int im2col_step);

std::vector<Tensor> ms_deform_attn_backward(const Tensor &value,
const Tensor &spatial_shapes,
const Tensor &level_start_index,
const Tensor &sampling_loc,
const Tensor &attn_weight,
const Tensor &grad_output,
const int im2col_step);
Tensor ms_deform_attn_forward(
const Tensor &value, const Tensor &spatial_shapes,
const Tensor &level_start_index, const Tensor &sampling_loc,
const Tensor &attn_weight, const int im2col_step);

void ms_deform_attn_backward(
const Tensor &value, const Tensor &spatial_shapes,
const Tensor &level_start_index, const Tensor &sampling_loc,
const Tensor &attn_weight, const Tensor &grad_output, Tensor &grad_value,
Tensor &grad_sampling_loc, Tensor &grad_attn_weight, const int im2col_step);

Tensor nms(Tensor boxes, Tensor scores, float iou_threshold, int offset);

Expand Down Expand Up @@ -445,5 +443,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
py::arg("value"), py::arg("value_spatial_shapes"),
py::arg("value_level_start_index"), py::arg("sampling_locations"),
py::arg("attention_weights"), py::arg("grad_output"),
py::arg("im2col_step"));
py::arg("grad_value"), py::arg("grad_sampling_loc"),
py::arg("grad_attn_weight"), py::arg("im2col_step"));
}
26 changes: 16 additions & 10 deletions mmcv/ops/multi_scale_deform_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def forward(ctx, value, value_spatial_shapes, value_level_start_index,
value_level_start_index,
sampling_locations,
attention_weights,
ctx.im2col_step)
im2col_step=ctx.im2col_step)
ctx.save_for_backward(value, value_spatial_shapes,
value_level_start_index, sampling_locations,
attention_weights)
Expand All @@ -60,15 +60,21 @@ def backward(ctx, grad_output):
"""
value, value_spatial_shapes, value_level_start_index,\
sampling_locations, attention_weights = ctx.saved_tensors
grad_value, grad_sampling_loc, grad_attn_weight = \
ext_module.ms_deform_attn_backward(
value,
value_spatial_shapes,
value_level_start_index,
sampling_locations,
attention_weights,
grad_output,
ctx.im2col_step)
grad_value = torch.zeros_like(value)
grad_sampling_loc = torch.zeros_like(sampling_locations)
grad_attn_weight = torch.zeros_like(attention_weights)

ext_module.ms_deform_attn_backward(
value,
value_spatial_shapes,
value_level_start_index,
sampling_locations,
attention_weights,
grad_output.contiguous(),
grad_value,
grad_sampling_loc,
grad_attn_weight,
im2col_step=ctx.im2col_step)

return grad_value, None, None, \
grad_sampling_loc, grad_attn_weight, None
Expand Down
22 changes: 17 additions & 5 deletions mmcv/utils/ext_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,19 +19,31 @@ def load_ext(name, funcs):
'nms', 'softnms', 'nms_match', 'nms_rotated', 'top_pool_forward',
'top_pool_backward', 'bottom_pool_forward', 'bottom_pool_backward',
'left_pool_forward', 'left_pool_backward', 'right_pool_forward',
'right_pool_backward', 'fused_bias_leakyrelu', 'upfirdn2d'
'right_pool_backward', 'fused_bias_leakyrelu', 'upfirdn2d',
'ms_deform_attn_forward',
]

def get_fake_func(name):
def fake_func(*args, **kwargs):
raise RuntimeError('{} is not supported in parrots now'.format(
name))
return fake_func

def load_ext(name, funcs):
ExtModule = namedtuple('ExtModule', funcs)
ext_list = []
lib_root = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
for fun in funcs:
if fun in has_return_value_ops:
ext_list.append(extension.load(fun, name, lib_dir=lib_root).op)
try:
ext_fun = extension.load(fun, name, lib_dir=lib_root)
except Exception:
ext_fun = get_fake_func(fun)
ext_list.append(ext_fun)
else:
ext_list.append(
extension.load(fun, name, lib_dir=lib_root).op_)
if fun in has_return_value_ops:
ext_list.append(ext_fun.op)
else:
ext_list.append(ext_fun.op_)
return ExtModule(*ext_list)


Expand Down
Loading

0 comments on commit bec2e15

Please sign in to comment.