diff --git a/mmcv/ops/csrc/parrots/ms_deform_attn.cpp b/mmcv/ops/csrc/parrots/ms_deform_attn.cpp
new file mode 100644
index 00000000000..85ad26253a2
--- /dev/null
+++ b/mmcv/ops/csrc/parrots/ms_deform_attn.cpp
@@ -0,0 +1,83 @@
+/*!
+**************************************************************************************************
+* Deformable DETR
+* Copyright (c) 2020 SenseTime. All Rights Reserved.
+* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+**************************************************************************************************
+* Modified from
+*https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
+**************************************************************************************************
+*/
+
+#include "pytorch_cpp_helper.hpp"
+
+#ifdef MMCV_WITH_CUDA
+Tensor ms_deform_attn_cuda_forward(const Tensor &value,
+                                   const Tensor &spatial_shapes,
+                                   const Tensor &level_start_index,
+                                   const Tensor &sampling_loc,
+                                   const Tensor &attn_weight,
+                                   const int im2col_step);
+
+void ms_deform_attn_cuda_backward(
+    const Tensor &value, const Tensor &spatial_shapes,
+    const Tensor &level_start_index, const Tensor &sampling_loc,
+    const Tensor &attn_weight, const Tensor &grad_output,
+    Tensor &grad_value, Tensor &grad_sampling_loc, Tensor &grad_attn_weight,
+    const int im2col_step);
+
+#endif
+
+Tensor ms_deform_attn_forward(const Tensor &value, const Tensor &spatial_shapes,
+                              const Tensor &level_start_index,
+                              const Tensor &sampling_loc,
+                              const Tensor &attn_weight,
+                              const int im2col_step) {
+  if (value.type().is_cuda()) {
+#ifdef MMCV_WITH_CUDA
+    CHECK_CUDA_INPUT(value)
+    CHECK_CUDA_INPUT(spatial_shapes)
+    CHECK_CUDA_INPUT(level_start_index)
+    CHECK_CUDA_INPUT(sampling_loc)
+    CHECK_CUDA_INPUT(attn_weight)
+    return ms_deform_attn_cuda_forward(value, spatial_shapes, level_start_index,
+                                       sampling_loc, attn_weight, im2col_step);
+#else
+    AT_ERROR("Not compiled with GPU support");
+#endif
+  }
+  AT_ERROR("Not implemented on the CPU");
+}
+
+void ms_deform_attn_backward(const Tensor &value,
+                             const Tensor &spatial_shapes,
+                             const Tensor &level_start_index,
+                             const Tensor &sampling_loc,
+                             const Tensor &attn_weight,
+                             const Tensor &grad_output,
+                             Tensor &grad_value,
+                             Tensor &grad_sampling_loc,
+                             Tensor &grad_attn_weight,
+                             const int im2col_step) {
+  if (value.type().is_cuda()) {
+#ifdef MMCV_WITH_CUDA
+    CHECK_CUDA_INPUT(value)
+    CHECK_CUDA_INPUT(spatial_shapes)
+    CHECK_CUDA_INPUT(level_start_index)
+    CHECK_CUDA_INPUT(sampling_loc)
+    CHECK_CUDA_INPUT(attn_weight)
+    CHECK_CUDA_INPUT(grad_output)
+    CHECK_CUDA_INPUT(grad_value)
+    CHECK_CUDA_INPUT(grad_sampling_loc)
+    CHECK_CUDA_INPUT(grad_attn_weight)
+    ms_deform_attn_cuda_backward(
+        value, spatial_shapes, level_start_index, sampling_loc, attn_weight,
+        grad_output, grad_value, grad_sampling_loc, grad_attn_weight,
+        im2col_step);
+#else
+    AT_ERROR("Not compiled with GPU support");
+#endif
+  } else {
+    AT_ERROR("Not implemented on the CPU");
+  }
+}
diff --git a/mmcv/ops/csrc/parrots/ms_deform_attn_cuda.cu b/mmcv/ops/csrc/parrots/ms_deform_attn_cuda.cu
new file mode 100644
index 00000000000..693131b3829
--- /dev/null
+++ b/mmcv/ops/csrc/parrots/ms_deform_attn_cuda.cu
@@ -0,0 +1,360 @@
+/*!
+**************************************************************************************************
+* Deformable DETR
+* Copyright (c) 2020 SenseTime. All Rights Reserved.
+* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+**************************************************************************************************
+* Modified from
+*https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
+**************************************************************************************************
+*/
+
+#include <ATen/ATen.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <cuda.h>
+#include <cuda_runtime.h>
+
+#include <algorithm>
+#include <cstdio>
+#include <vector>
+
+template <typename scalar_t>
+void ms_deformable_im2col_cuda(cudaStream_t stream, const scalar_t *data_value,
+                               const int64_t *data_spatial_shapes,
+                               const int64_t *data_level_start_index,
+                               const scalar_t *data_sampling_loc,
+                               const scalar_t *data_attn_weight,
+                               const int batch_size, const int spatial_size,
+                               const int num_heads, const int channels,
+                               const int num_levels, const int num_query,
+                               const int num_point, scalar_t *data_col) {
+  const int num_kernels = batch_size * num_query * num_heads * channels;
+  const int num_actual_kernels = batch_size * num_query * num_heads * channels;
+  const int num_threads = CUDA_NUM_THREADS;
+  ms_deformable_im2col_gpu_kernel<scalar_t>
+      <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0, stream>>>(
+          num_kernels, data_value, data_spatial_shapes, data_level_start_index,
+          data_sampling_loc, data_attn_weight, batch_size, spatial_size,
+          num_heads, channels, num_levels, num_query, num_point, data_col);
+
+  cudaError_t err = cudaGetLastError();
+  if (err != cudaSuccess) {
+    printf("error in ms_deformable_im2col_cuda: %s\n", cudaGetErrorString(err));
+  }
+}
+
+template <typename scalar_t>
+void ms_deformable_col2im_cuda(
+    cudaStream_t stream, const scalar_t *grad_col, const scalar_t *data_value,
+    const int64_t *data_spatial_shapes, const int64_t *data_level_start_index,
+    const scalar_t *data_sampling_loc, const scalar_t *data_attn_weight,
+    const int batch_size, const int spatial_size, const int num_heads,
+    const int channels, const int num_levels, const int num_query,
+    const int num_point, scalar_t *grad_value, scalar_t *grad_sampling_loc,
+    scalar_t *grad_attn_weight) {
+  const int num_threads =
+      (channels > CUDA_NUM_THREADS) ? CUDA_NUM_THREADS : channels;
+  const int num_kernels = batch_size * num_query * num_heads * channels;
+  const int num_actual_kernels = batch_size * num_query * num_heads * channels;
+  if (channels > 1024) {
+    if ((channels & 1023) == 0) {
+      ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks<scalar_t>
+          <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
+             num_threads * 3 * sizeof(scalar_t), stream>>>(
+              num_kernels, grad_col, data_value, data_spatial_shapes,
+              data_level_start_index, data_sampling_loc, data_attn_weight,
+              batch_size, spatial_size, num_heads, channels, num_levels,
+              num_query, num_point, grad_value, grad_sampling_loc,
+              grad_attn_weight);
+    } else {
+      ms_deformable_col2im_gpu_kernel_gm<scalar_t>
+          <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0,
+             stream>>>(num_kernels, grad_col, data_value, data_spatial_shapes,
+                       data_level_start_index, data_sampling_loc,
+                       data_attn_weight, batch_size, spatial_size, num_heads,
+                       channels, num_levels, num_query, num_point, grad_value,
+                       grad_sampling_loc, grad_attn_weight);
+    }
+  } else {
+    switch (channels) {
+      case 1:
+        ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<
+            scalar_t, 1>
+            <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0,
+               stream>>>(num_kernels, grad_col, data_value, data_spatial_shapes,
+                         data_level_start_index, data_sampling_loc,
+                         data_attn_weight, batch_size, spatial_size, num_heads,
+                         channels, num_levels, num_query, num_point, grad_value,
+                         grad_sampling_loc, grad_attn_weight);
+        break;
+      case 2:
+        ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<
+            scalar_t, 2>
+            <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0,
+               stream>>>(num_kernels, grad_col, data_value, data_spatial_shapes,
+                         data_level_start_index, data_sampling_loc,
+                         data_attn_weight, batch_size, spatial_size, num_heads,
+                         channels, num_levels, num_query, num_point, grad_value,
+                         grad_sampling_loc, grad_attn_weight);
+        break;
+      case 4:
+        ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<
+            scalar_t, 4>
+            <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0,
+               stream>>>(num_kernels, grad_col, data_value, data_spatial_shapes,
+                         data_level_start_index, data_sampling_loc,
+                         data_attn_weight, batch_size, spatial_size, num_heads,
+                         channels, num_levels, num_query, num_point, grad_value,
+                         grad_sampling_loc, grad_attn_weight);
+        break;
+      case 8:
+        ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<
+            scalar_t, 8>
+            <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0,
+               stream>>>(num_kernels, grad_col, data_value, data_spatial_shapes,
+                         data_level_start_index, data_sampling_loc,
+                         data_attn_weight, batch_size, spatial_size, num_heads,
+                         channels, num_levels, num_query, num_point, grad_value,
+                         grad_sampling_loc, grad_attn_weight);
+        break;
+      case 16:
+        ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<
+            scalar_t, 16>
+            <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0,
+               stream>>>(num_kernels, grad_col, data_value, data_spatial_shapes,
+                         data_level_start_index, data_sampling_loc,
+                         data_attn_weight, batch_size, spatial_size, num_heads,
+                         channels, num_levels, num_query, num_point, grad_value,
+                         grad_sampling_loc, grad_attn_weight);
+        break;
+      case 32:
+        ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<
+            scalar_t, 32>
+            <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0,
+               stream>>>(num_kernels, grad_col, data_value, data_spatial_shapes,
+                         data_level_start_index, data_sampling_loc,
+                         data_attn_weight, batch_size, spatial_size, num_heads,
+                         channels, num_levels, num_query, num_point, grad_value,
+                         grad_sampling_loc, grad_attn_weight);
+        break;
+      case 64:
+        ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<
+            scalar_t, 64>
+            <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0,
+               stream>>>(num_kernels, grad_col, data_value, data_spatial_shapes,
+                         data_level_start_index, data_sampling_loc,
+                         data_attn_weight, batch_size, spatial_size, num_heads,
+                         channels, num_levels, num_query, num_point, grad_value,
+                         grad_sampling_loc, grad_attn_weight);
+        break;
+      case 128:
+        ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<
+            scalar_t, 128>
+            <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0,
+               stream>>>(num_kernels, grad_col, data_value, data_spatial_shapes,
+                         data_level_start_index, data_sampling_loc,
+                         data_attn_weight, batch_size, spatial_size, num_heads,
+                         channels, num_levels, num_query, num_point, grad_value,
+                         grad_sampling_loc, grad_attn_weight);
+        break;
+      case 256:
+        ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<
+            scalar_t, 256>
+            <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0,
+               stream>>>(num_kernels, grad_col, data_value, data_spatial_shapes,
+                         data_level_start_index, data_sampling_loc,
+                         data_attn_weight, batch_size, spatial_size, num_heads,
+                         channels, num_levels, num_query, num_point, grad_value,
+                         grad_sampling_loc, grad_attn_weight);
+        break;
+      case 512:
+        ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<
+            scalar_t, 512>
+            <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0,
+               stream>>>(num_kernels, grad_col, data_value, data_spatial_shapes,
+                         data_level_start_index, data_sampling_loc,
+                         data_attn_weight, batch_size, spatial_size, num_heads,
+                         channels, num_levels, num_query, num_point, grad_value,
+                         grad_sampling_loc, grad_attn_weight);
+        break;
+      case 1024:
+        ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<
+            scalar_t, 1024>
+            <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0,
+               stream>>>(num_kernels, grad_col, data_value, data_spatial_shapes,
+                         data_level_start_index, data_sampling_loc,
+                         data_attn_weight, batch_size, spatial_size, num_heads,
+                         channels, num_levels, num_query, num_point, grad_value,
+                         grad_sampling_loc, grad_attn_weight);
+        break;
+      default:
+        if (channels < 64) {
+          ms_deformable_col2im_gpu_kernel_shm_reduce_v1<scalar_t>
+              <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
+                 num_threads * 3 * sizeof(scalar_t), stream>>>(
+                  num_kernels, grad_col, data_value, data_spatial_shapes,
+                  data_level_start_index, data_sampling_loc, data_attn_weight,
+                  batch_size, spatial_size, num_heads, channels, num_levels,
+                  num_query, num_point, grad_value, grad_sampling_loc,
+                  grad_attn_weight);
+        } else {
+          ms_deformable_col2im_gpu_kernel_shm_reduce_v2<scalar_t>
+              <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
+                 num_threads * 3 * sizeof(scalar_t), stream>>>(
+                  num_kernels, grad_col, data_value, data_spatial_shapes,
+                  data_level_start_index, data_sampling_loc, data_attn_weight,
+                  batch_size, spatial_size, num_heads, channels, num_levels,
+                  num_query, num_point, grad_value, grad_sampling_loc,
+                  grad_attn_weight);
+        }
+    }
+  }
+  cudaError_t err = cudaGetLastError();
+  if (err != cudaSuccess) {
+    printf("error in ms_deformable_col2im_cuda: %s\n", cudaGetErrorString(err));
+  }
+}
+
+at::Tensor ms_deform_attn_cuda_forward(const at::Tensor &value,
+                                       const at::Tensor &spatial_shapes,
+                                       const at::Tensor &level_start_index,
+                                       const at::Tensor &sampling_loc,
+                                       const at::Tensor &attn_weight,
+                                       const int im2col_step) {
+  AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous");
+  AT_ASSERTM(spatial_shapes.is_contiguous(),
+             "spatial_shapes tensor has to be contiguous");
+  AT_ASSERTM(level_start_index.is_contiguous(),
+             "level_start_index tensor has to be contiguous");
+  AT_ASSERTM(sampling_loc.is_contiguous(),
+             "sampling_loc tensor has to be contiguous");
+  AT_ASSERTM(attn_weight.is_contiguous(),
+             "attn_weight tensor has to be contiguous");
+
+  AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor");
+  AT_ASSERTM(spatial_shapes.type().is_cuda(),
+             "spatial_shapes must be a CUDA tensor");
+  AT_ASSERTM(level_start_index.type().is_cuda(),
+             "level_start_index must be a CUDA tensor");
+  AT_ASSERTM(sampling_loc.type().is_cuda(),
+             "sampling_loc must be a CUDA tensor");
+  AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor");
+
+  const int batch = value.size(0);
+  const int spatial_size = value.size(1);
+  const int num_heads = value.size(2);
+  const int channels = value.size(3);
+
+  const int num_levels = spatial_shapes.size(0);
+
+  const int num_query = sampling_loc.size(1);
+  const int num_point = sampling_loc.size(4);
+
+  const int im2col_step_ = std::min(batch, im2col_step);
+
+  AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)",
+             batch, im2col_step_);
+
+  auto output =
+      at::zeros({batch, num_query, num_heads, channels}, value.options());
+
+  const int batch_n = im2col_step_;
+  auto output_n = output.view(
+      {batch / im2col_step_, batch_n, num_query, num_heads, channels});
+  auto per_value_size = spatial_size * num_heads * channels;
+  auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2;
+  auto per_attn_weight_size = num_query * num_heads * num_levels * num_point;
+  for (int n = 0; n < batch / im2col_step_; ++n) {
+    auto columns = output_n.select(0, n);
+    AT_DISPATCH_FLOATING_TYPES(
+        value.type(), "ms_deform_attn_forward_cuda", ([&] {
+          ms_deformable_im2col_cuda(
+              at::cuda::getCurrentCUDAStream(),
+              value.data<scalar_t>() + n * im2col_step_ * per_value_size,
+              spatial_shapes.data<int64_t>(),
+              level_start_index.data<int64_t>(),
+              sampling_loc.data<scalar_t>() +
+                  n * im2col_step_ * per_sample_loc_size,
+              attn_weight.data<scalar_t>() +
+                  n * im2col_step_ * per_attn_weight_size,
+              batch_n, spatial_size, num_heads, channels, num_levels, num_query,
+              num_point, columns.data<scalar_t>());
+        }));
+  }
+
+  output = output.view({batch, num_query, num_heads * channels});
+
+  return output;
+}
+
+void ms_deform_attn_cuda_backward(
+    const at::Tensor &value, const at::Tensor &spatial_shapes,
+    const at::Tensor &level_start_index, const at::Tensor &sampling_loc,
+    const at::Tensor &attn_weight, const at::Tensor &grad_output,
+    at::Tensor &grad_value, at::Tensor &grad_sampling_loc,
+    at::Tensor &grad_attn_weight, const int im2col_step) {
+  AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous");
+  AT_ASSERTM(spatial_shapes.is_contiguous(),
+             "spatial_shapes tensor has to be contiguous");
+  AT_ASSERTM(level_start_index.is_contiguous(),
+             "level_start_index tensor has to be contiguous");
+  AT_ASSERTM(sampling_loc.is_contiguous(),
+             "sampling_loc tensor has to be contiguous");
+  AT_ASSERTM(attn_weight.is_contiguous(),
+             "attn_weight tensor has to be contiguous");
+  AT_ASSERTM(grad_output.is_contiguous(),
+             "grad_output tensor has to be contiguous");
+
+  AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor");
+  AT_ASSERTM(spatial_shapes.type().is_cuda(),
+             "spatial_shapes must be a CUDA tensor");
+  AT_ASSERTM(level_start_index.type().is_cuda(),
+             "level_start_index must be a CUDA tensor");
+  AT_ASSERTM(sampling_loc.type().is_cuda(),
+             "sampling_loc must be a CUDA tensor");
+  AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor");
+  AT_ASSERTM(grad_output.type().is_cuda(), "grad_output must be a CUDA tensor");
+
+  const int batch = value.size(0);
+  const int spatial_size = value.size(1);
+  const int num_heads = value.size(2);
+  const int channels = value.size(3);
+
+  const int num_levels = spatial_shapes.size(0);
+
+  const int num_query = sampling_loc.size(1);
+  const int num_point = sampling_loc.size(4);
+
+  const int im2col_step_ = std::min(batch, im2col_step);
+
+  AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)",
+             batch, im2col_step_);
+
+  const int batch_n = im2col_step_;
+  auto per_value_size = spatial_size * num_heads * channels;
+  auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2;
+  auto per_attn_weight_size = num_query * num_heads * num_levels * num_point;
+  auto grad_output_n = grad_output.view(
+      {batch / im2col_step_, batch_n, num_query, num_heads, channels});
+
+  for (int n = 0; n < batch / im2col_step_; ++n) {
+    auto grad_output_g = grad_output_n.select(0, n);
+    AT_DISPATCH_FLOATING_TYPES(
"ms_deform_attn_backward_cuda", ([&] { + ms_deformable_col2im_cuda( + at::cuda::getCurrentCUDAStream(), grad_output_g.data(), + value.data() + n * im2col_step_ * per_value_size, + spatial_shapes.data(), level_start_index.data(), + sampling_loc.data() + + n * im2col_step_ * per_sample_loc_size, + attn_weight.data() + + n * im2col_step_ * per_attn_weight_size, + batch_n, spatial_size, num_heads, channels, num_levels, num_query, + num_point, + grad_value.data() + n * im2col_step_ * per_value_size, + grad_sampling_loc.data() + + n * im2col_step_ * per_sample_loc_size, + grad_attn_weight.data() + + n * im2col_step_ * per_attn_weight_size); + })); + } +} diff --git a/mmcv/ops/csrc/parrots/ms_deform_attn_parrots.cpp b/mmcv/ops/csrc/parrots/ms_deform_attn_parrots.cpp new file mode 100644 index 00000000000..cdaf2237ed2 --- /dev/null +++ b/mmcv/ops/csrc/parrots/ms_deform_attn_parrots.cpp @@ -0,0 +1,81 @@ +#include + +#include +#include +#include +using namespace at; +using namespace parrots; + +Tensor ms_deform_attn_forward(const Tensor &value, const Tensor &spatial_shapes, + const Tensor &level_start_index, + const Tensor &sampling_loc, + const Tensor &attn_weight, + const int im2col_step); + +void ms_deform_attn_backward(const Tensor &value, + const Tensor &spatial_shapes, + const Tensor &level_start_index, + const Tensor &sampling_loc, + const Tensor &attn_weight, + const Tensor &grad_output, + Tensor &grad_value, + Tensor &grad_sampling_loc, + Tensor &grad_attn_weight, + const int im2col_step); + +void ms_deform_attn_forward_parrots(CudaContext &ctx, const SSElement &attr, + const OperatorBase::in_list_t &ins, + OperatorBase::out_list_t &outs) { + int im2col_step; + SSAttrs(attr) + .get("im2col_step", im2col_step) + .done(); + const auto &value = buildATensor(ctx, ins[0]); + const auto &spatial_shapes = buildATensor(ctx, ins[1]); + const auto &level_start_index = buildATensor(ctx, ins[2]); + const auto &sampling_loc = buildATensor(ctx, ins[3]); + const auto &attn_weight = buildATensor(ctx, ins[4]); + auto out = ms_deform_attn_forward( + value, spatial_shapes, level_start_index, sampling_loc, attn_weight, + im2col_step); + updateDArray(ctx, out, outs[0]); +} + + + + +void ms_deform_attn_backward_parrots(CudaContext &ctx, const SSElement &attr, + const OperatorBase::in_list_t &ins, + OperatorBase::out_list_t &outs) { + int im2col_step; + SSAttrs(attr) + .get("im2col_step", im2col_step) + .done(); + const auto &value = buildATensor(ctx, ins[0]); + const auto &spatial_shapes = buildATensor(ctx, ins[1]); + const auto &level_start_index = buildATensor(ctx, ins[2]); + const auto &sampling_loc = buildATensor(ctx, ins[3]); + const auto &attn_weight = buildATensor(ctx, ins[4]); + const auto &grad_output = buildATensor(ctx, ins[5]); + auto grad_value = buildATensor(ctx, outs[0]); + auto grad_sampling_loc = buildATensor(ctx, outs[1]); + auto grad_attn_weight = buildATensor(ctx, outs[2]); + ms_deform_attn_backward( + value, spatial_shapes, level_start_index, sampling_loc, attn_weight, + grad_output, grad_value, grad_sampling_loc, grad_attn_weight, + im2col_step); +} + +PARROTS_EXTENSION_REGISTER(ms_deform_attn_forward) + .attr("im2col_step") + .input(5) + .output(1) + .apply(ms_deform_attn_forward_parrots) + .done(); + +PARROTS_EXTENSION_REGISTER(ms_deform_attn_backward) + .attr("im2col_step") + .input(6) + .output(3) + .apply(ms_deform_attn_backward_parrots) + .done(); diff --git a/mmcv/ops/csrc/pytorch/ms_deform_attn.cpp b/mmcv/ops/csrc/pytorch/ms_deform_attn.cpp index 9bcee5c2430..85ad26253a2 100644 --- 
--- a/mmcv/ops/csrc/pytorch/ms_deform_attn.cpp
+++ b/mmcv/ops/csrc/pytorch/ms_deform_attn.cpp
@@ -19,10 +19,11 @@ Tensor ms_deform_attn_cuda_forward(const Tensor &value,
                                    const Tensor &attn_weight,
                                    const int im2col_step);
 
-std::vector<Tensor> ms_deform_attn_cuda_backward(
+void ms_deform_attn_cuda_backward(
     const Tensor &value, const Tensor &spatial_shapes,
     const Tensor &level_start_index, const Tensor &sampling_loc,
     const Tensor &attn_weight, const Tensor &grad_output,
+    Tensor &grad_value, Tensor &grad_sampling_loc, Tensor &grad_attn_weight,
     const int im2col_step);
 
 #endif
@@ -48,13 +49,16 @@ Tensor ms_deform_attn_forward(const Tensor &value, const Tensor &spatial_shapes,
   AT_ERROR("Not implemented on the CPU");
 }
 
-std::vector<Tensor> ms_deform_attn_backward(const Tensor &value,
-                                            const Tensor &spatial_shapes,
-                                            const Tensor &level_start_index,
-                                            const Tensor &sampling_loc,
-                                            const Tensor &attn_weight,
-                                            const Tensor &grad_output,
-                                            const int im2col_step) {
+void ms_deform_attn_backward(const Tensor &value,
+                             const Tensor &spatial_shapes,
+                             const Tensor &level_start_index,
+                             const Tensor &sampling_loc,
+                             const Tensor &attn_weight,
+                             const Tensor &grad_output,
+                             Tensor &grad_value,
+                             Tensor &grad_sampling_loc,
+                             Tensor &grad_attn_weight,
+                             const int im2col_step) {
   if (value.type().is_cuda()) {
 #ifdef MMCV_WITH_CUDA
     CHECK_CUDA_INPUT(value)
@@ -63,12 +67,17 @@ std::vector<Tensor> ms_deform_attn_backward(const Tensor &value,
     CHECK_CUDA_INPUT(sampling_loc)
     CHECK_CUDA_INPUT(attn_weight)
     CHECK_CUDA_INPUT(grad_output)
-    return ms_deform_attn_cuda_backward(value, spatial_shapes,
-                                        level_start_index, sampling_loc,
-                                        attn_weight, grad_output, im2col_step);
+    CHECK_CUDA_INPUT(grad_value)
+    CHECK_CUDA_INPUT(grad_sampling_loc)
+    CHECK_CUDA_INPUT(grad_attn_weight)
+    ms_deform_attn_cuda_backward(
+        value, spatial_shapes, level_start_index, sampling_loc, attn_weight,
+        grad_output, grad_value, grad_sampling_loc, grad_attn_weight,
+        im2col_step);
 #else
     AT_ERROR("Not compiled with GPU support");
 #endif
-  }
+  } else {
     AT_ERROR("Not implemented on the CPU");
+  }
 }
diff --git a/mmcv/ops/csrc/pytorch/ms_deform_attn_cuda.cu b/mmcv/ops/csrc/pytorch/ms_deform_attn_cuda.cu
index 1cd67403f01..693131b3829 100644
--- a/mmcv/ops/csrc/pytorch/ms_deform_attn_cuda.cu
+++ b/mmcv/ops/csrc/pytorch/ms_deform_attn_cuda.cu
@@ -286,11 +286,12 @@ at::Tensor ms_deform_attn_cuda_forward(const at::Tensor &value,
   return output;
 }
 
-std::vector<at::Tensor> ms_deform_attn_cuda_backward(
+void ms_deform_attn_cuda_backward(
     const at::Tensor &value, const at::Tensor &spatial_shapes,
     const at::Tensor &level_start_index, const at::Tensor &sampling_loc,
     const at::Tensor &attn_weight, const at::Tensor &grad_output,
-    const int im2col_step) {
+    at::Tensor &grad_value, at::Tensor &grad_sampling_loc,
+    at::Tensor &grad_attn_weight, const int im2col_step) {
   AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous");
   AT_ASSERTM(spatial_shapes.is_contiguous(),
              "spatial_shapes tensor has to be contiguous");
@@ -328,10 +329,6 @@ std::vector<at::Tensor> ms_deform_attn_cuda_backward(
   AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)",
              batch, im2col_step_);
 
-  auto grad_value = at::zeros_like(value);
-  auto grad_sampling_loc = at::zeros_like(sampling_loc);
-  auto grad_attn_weight = at::zeros_like(attn_weight);
-
   const int batch_n = im2col_step_;
   auto per_value_size = spatial_size * num_heads * channels;
   auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2;
@@ -360,6 +357,4 @@ std::vector<at::Tensor> ms_deform_attn_cuda_backward(
                   n * im2col_step_ * per_attn_weight_size);
         }));
   }
-
-  return {grad_value, grad_sampling_loc, grad_attn_weight};
 }
diff --git a/mmcv/ops/csrc/pytorch/pybind.cpp b/mmcv/ops/csrc/pytorch/pybind.cpp
index 6e1096b6fba..2be35f278e0 100644
--- a/mmcv/ops/csrc/pytorch/pybind.cpp
+++ b/mmcv/ops/csrc/pytorch/pybind.cpp
@@ -92,18 +92,16 @@ void modulated_deform_conv_backward(
     int pad_w, int dilation_h, int dilation_w, int group, int deformable_group,
     const bool with_bias);
 
-Tensor ms_deform_attn_forward(const Tensor &value, const Tensor &spatial_shapes,
-                              const Tensor &level_start_index,
-                              const Tensor &sampling_loc,
-                              const Tensor &attn_weight, const int im2col_step);
-
-std::vector<Tensor> ms_deform_attn_backward(const Tensor &value,
-                                            const Tensor &spatial_shapes,
-                                            const Tensor &level_start_index,
-                                            const Tensor &sampling_loc,
-                                            const Tensor &attn_weight,
-                                            const Tensor &grad_output,
-                                            const int im2col_step);
+Tensor ms_deform_attn_forward(
+    const Tensor &value, const Tensor &spatial_shapes,
+    const Tensor &level_start_index, const Tensor &sampling_loc,
+    const Tensor &attn_weight, const int im2col_step);
+
+void ms_deform_attn_backward(
+    const Tensor &value, const Tensor &spatial_shapes,
+    const Tensor &level_start_index, const Tensor &sampling_loc,
+    const Tensor &attn_weight, const Tensor &grad_output, Tensor &grad_value,
+    Tensor &grad_sampling_loc, Tensor &grad_attn_weight, const int im2col_step);
 
 Tensor nms(Tensor boxes, Tensor scores, float iou_threshold, int offset);
 
@@ -445,5 +443,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
         py::arg("value"), py::arg("value_spatial_shapes"),
         py::arg("value_level_start_index"), py::arg("sampling_locations"),
         py::arg("attention_weights"), py::arg("grad_output"),
-        py::arg("im2col_step"));
+        py::arg("grad_value"), py::arg("grad_sampling_loc"),
+        py::arg("grad_attn_weight"), py::arg("im2col_step"));
 }
diff --git a/mmcv/ops/multi_scale_deform_attn.py b/mmcv/ops/multi_scale_deform_attn.py
index 77919e47ece..db88f82739b 100644
--- a/mmcv/ops/multi_scale_deform_attn.py
+++ b/mmcv/ops/multi_scale_deform_attn.py
@@ -39,7 +39,7 @@ def forward(ctx, value, value_spatial_shapes, value_level_start_index,
             value_level_start_index,
             sampling_locations,
             attention_weights,
-            ctx.im2col_step)
+            im2col_step=ctx.im2col_step)
         ctx.save_for_backward(value, value_spatial_shapes,
                               value_level_start_index, sampling_locations,
                               attention_weights)
@@ -60,15 +60,21 @@ def backward(ctx, grad_output):
         """
         value, value_spatial_shapes, value_level_start_index,\
             sampling_locations, attention_weights = ctx.saved_tensors
-        grad_value, grad_sampling_loc, grad_attn_weight = \
-            ext_module.ms_deform_attn_backward(
-                value,
-                value_spatial_shapes,
-                value_level_start_index,
-                sampling_locations,
-                attention_weights,
-                grad_output,
-                ctx.im2col_step)
+        grad_value = torch.zeros_like(value)
+        grad_sampling_loc = torch.zeros_like(sampling_locations)
+        grad_attn_weight = torch.zeros_like(attention_weights)
+
+        ext_module.ms_deform_attn_backward(
+            value,
+            value_spatial_shapes,
+            value_level_start_index,
+            sampling_locations,
+            attention_weights,
+            grad_output.contiguous(),
+            grad_value,
+            grad_sampling_loc,
+            grad_attn_weight,
+            im2col_step=ctx.im2col_step)
 
         return grad_value, None, None, \
             grad_sampling_loc, grad_attn_weight, None
diff --git a/mmcv/utils/ext_loader.py b/mmcv/utils/ext_loader.py
index 826e70bb166..18fcc64fced 100644
--- a/mmcv/utils/ext_loader.py
+++ b/mmcv/utils/ext_loader.py
@@ -19,19 +19,31 @@ def load_ext(name, funcs):
         'nms', 'softnms', 'nms_match', 'nms_rotated', 'top_pool_forward',
         'top_pool_backward', 'bottom_pool_forward', 'bottom_pool_backward',
         'left_pool_forward', 'left_pool_backward', 'right_pool_forward',
-        'right_pool_backward', 'fused_bias_leakyrelu', 'upfirdn2d'
+        'right_pool_backward', 'fused_bias_leakyrelu', 'upfirdn2d',
+        'ms_deform_attn_forward',
     ]
 
+    def get_fake_func(name):
+
+        def fake_func(*args, **kwargs):
+            raise RuntimeError('{} is not supported in parrots now'.format(
+                name))
+
+        return fake_func
+
     def load_ext(name, funcs):
         ExtModule = namedtuple('ExtModule', funcs)
         ext_list = []
         lib_root = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
         for fun in funcs:
-            if fun in has_return_value_ops:
-                ext_list.append(extension.load(fun, name, lib_dir=lib_root).op)
+            try:
+                ext_fun = extension.load(fun, name, lib_dir=lib_root)
+            except Exception:
+                ext_fun = get_fake_func(fun)
+                ext_list.append(ext_fun)
             else:
-                ext_list.append(
-                    extension.load(fun, name, lib_dir=lib_root).op_)
+                if fun in has_return_value_ops:
+                    ext_list.append(ext_fun.op)
+                else:
+                    ext_list.append(ext_fun.op_)
         return ExtModule(*ext_list)
diff --git a/tests/test_ops/test_ms_deformable_attn.py b/tests/test_ops/test_ms_deformable_attn.py
index 39d371fcb37..92a2f3667ec 100644
--- a/tests/test_ops/test_ms_deformable_attn.py
+++ b/tests/test_ops/test_ms_deformable_attn.py
@@ -1,10 +1,16 @@
 import pytest
 import torch
-from torch.autograd import gradcheck
 
 from mmcv.ops.multi_scale_deform_attn import (
     MultiScaleDeformableAttnFunction, multi_scale_deformable_attn_pytorch)
 
+_USING_PARROTS = True
+try:
+    from parrots.autograd import gradcheck
+except ImportError:
+    from torch.autograd import gradcheck
+    _USING_PARROTS = False
+
 
 def test_forward_multi_scale_deformable_attn_pytorch():
     N, M, D = 1, 2, 2
@@ -118,8 +124,14 @@ def test_gradient_numerical(channels,
     value.requires_grad = grad_value
     sampling_locations.requires_grad = grad_sampling_loc
     attention_weights.requires_grad = grad_attn_weight
-
-    assert gradcheck(
-        func,
-        (value.double(), shapes, level_start_index,
-         sampling_locations.double(), attention_weights.double(), im2col_step))
+    if _USING_PARROTS:
+        assert gradcheck(
+            func, (value.double(), shapes, level_start_index,
+                   sampling_locations.double(), attention_weights.double(),
+                   im2col_step),
+            no_grads=[shapes, level_start_index])
+    else:
+        assert gradcheck(
+            func, (value.double(), shapes, level_start_index,
+                   sampling_locations.double(), attention_weights.double(),
+                   im2col_step))
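
Usage sketch (not part of the patch): after this change the backward op fills caller-allocated gradient tensors in place instead of returning them, which is what the rewritten MultiScaleDeformableAttnFunction.backward relies on. A minimal end-to-end check, with illustrative shapes borrowed from tests/test_ops/test_ms_deformable_attn.py, could look like this:

    import torch

    from mmcv.ops.multi_scale_deform_attn import MultiScaleDeformableAttnFunction

    # Illustrative sizes: batch N, heads M, channels D, queries Lq, levels L,
    # sampling points P (same values as in the unit test above).
    N, M, D, Lq, L, P = 1, 2, 2, 2, 2, 2
    shapes = torch.as_tensor([(6, 4), (3, 2)], dtype=torch.long).cuda()
    level_start_index = torch.cat(
        (shapes.new_zeros((1, )), shapes.prod(1).cumsum(0)[:-1]))
    S = int(shapes.prod(1).sum())

    value = torch.rand(N, S, M, D).cuda() * 0.01
    sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda()
    attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5
    attention_weights /= attention_weights.sum(
        -1, keepdim=True).sum(-2, keepdim=True)
    value.requires_grad = True

    # forward dispatches to ext_module.ms_deform_attn_forward; backward
    # allocates grad_value / grad_sampling_loc / grad_attn_weight with
    # torch.zeros_like and hands them to ext_module.ms_deform_attn_backward,
    # which writes the gradients into them in place.
    output = MultiScaleDeformableAttnFunction.apply(
        value, shapes, level_start_index, sampling_locations,
        attention_weights, 2)
    output.sum().backward()
    assert value.grad.shape == value.shape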