From 2b47d7bdb51b02e5533231bff8018f88df23cae0 Mon Sep 17 00:00:00 2001 From: wondervictor Date: Tue, 8 Jun 2021 15:20:37 +0800 Subject: [PATCH 1/4] optimize criss cross attention --- mmcv/ops/csrc/cc_attention_cuda_kernel.cuh | 190 ++++----- mmcv/ops/csrc/pytorch/cc_attention_cuda.cu | 15 +- mmcv/ops/csrc/pytorch/pybind.cpp | 465 +++++++++++---------- 3 files changed, 330 insertions(+), 340 deletions(-) diff --git a/mmcv/ops/csrc/cc_attention_cuda_kernel.cuh b/mmcv/ops/csrc/cc_attention_cuda_kernel.cuh index 0dd9c33c66..a0cce0d83e 100644 --- a/mmcv/ops/csrc/cc_attention_cuda_kernel.cuh +++ b/mmcv/ops/csrc/cc_attention_cuda_kernel.cuh @@ -14,24 +14,22 @@ __global__ void ca_forward_kernel(const T *t, const T *f, T *weight, int num, int y = blockIdx.y * blockDim.y + threadIdx.y; int sp = height * width; int len = height + width - 1; - int z = blockIdx.z; - - if (x < width && y < height && z < height + width - 1) { - for (int batch = 0; batch < num; ++batch) { - for (int plane = 0; plane < chn; ++plane) { - T _t = t[(batch * chn + plane) * sp + y * width + x]; - - if (z < width) { - int i = z; - T _f = f[(batch * chn + plane) * sp + y * width + i]; - weight[(batch * len + i) * sp + y * width + x] += _t * _f; - } else { - int i = z - width; - int j = i < y ? i : i + 1; - - T _f = f[(batch * chn + plane) * sp + j * width + x]; - weight[(batch * len + width + i) * sp + y * width + x] += _t * _f; - } + int z = blockIdx.z % len; + int batch = blockIdx.z / len; + + if (x < width && y < height) { + for (int plane = 0; plane < chn; ++plane) { + T _t = t[(batch * chn + plane) * sp + y*width + x]; + + if (z < width) { + int i = z; + T _f = f[(batch * chn + plane) * sp + y*width + i]; + weight[(batch * len + i) * sp + y*width + x] += _t*_f; + } else { + int i = z - width; + int j = i y ? y : y - 1; + int plane = blockIdx.z % chn; + int batch = blockIdx.z / chn; + + if (x < width && y < height) { + for (int i = 0; i < width; ++i) { + T _dw = dw[(batch * len + x) * sp + y*width + i]; + T _t = t[(batch * chn + plane) * sp + y*width + i]; + df[(batch * chn + plane) * sp + y*width + x] += _dw * _t; + } + for (int i = 0; i < height; ++i) { + if (i == y) continue; + int j = i>y ? y : y-1; - T _dw = dw[(batch * len + width + j) * sp + i * width + x]; - T _t = t[(batch * chn + plane) * sp + i * width + x]; - df[(batch * chn + plane) * sp + y * width + x] += _dw * _t; - } + T _dw = dw[(batch * len + width + j) * sp + i*width + x]; + T _t = t[(batch * chn + plane) * sp + i*width + x]; + df[(batch * chn + plane) * sp + y*width + x] += _dw * _t; } } } @@ -100,24 +96,23 @@ __global__ void ca_map_forward_kernel(const T *weight, const T *g, T *out, int y = blockIdx.y * blockDim.y + threadIdx.y; int sp = height * width; int len = height + width - 1; - int plane = blockIdx.z; - - if (x < width && y < height && plane < chn) { - for (int batch = 0; batch < num; ++batch) { - for (int i = 0; i < width; ++i) { - T _g = g[(batch * chn + plane) * sp + y * width + i]; - T _w = weight[(batch * len + i) * sp + y * width + x]; - out[(batch * chn + plane) * sp + y * width + x] += _g * _w; - } - for (int i = 0; i < height; ++i) { - if (i == y) continue; + int plane = blockIdx.z % chn; + int batch = blockIdx.z / chn; + T res = 0; + if (x < width && y < height) { + for (int i = 0; i < width; ++i) { + T _g = g[(batch * chn + plane) * sp + y * width + i]; + T _w = weight[(batch * len + i) * sp + y * width + x]; + out[(batch * chn + plane) * sp + y * width + x] += _g * _w; + } + for (int i = 0; i < height; ++i) { + if (i == y) continue; - int j = i < y ? i : i - 1; + int j = i < y ? i : i - 1; - T _g = g[(batch * chn + plane) * sp + i * width + x]; - T _w = weight[(batch * len + width + j) * sp + y * width + x]; - out[(batch * chn + plane) * sp + y * width + x] += _g * _w; - } + T _g = g[(batch * chn + plane) * sp + i * width + x]; + T _w = weight[(batch * len + width + j) * sp + y * width + x]; + out[(batch * chn + plane) * sp + y * width + x] += _g * _w; } } } @@ -130,25 +125,23 @@ __global__ void ca_map_backward_kernel_w(const T *dout, const T *weight, int y = blockIdx.y * blockDim.y + threadIdx.y; int sp = height * width; int len = height + width - 1; - int z = blockIdx.z; - if (x < width && y < height && z < height + width - 1) { - for (int batch = 0; batch < num; ++batch) { - for (int plane = 0; plane < chn; ++plane) { - T _dout = dout[(batch * chn + plane) * sp + y * width + x]; - - if (z < width) { - int i = z; - T _g = g[(batch * chn + plane) * sp + y * width + i]; - dw[(batch * len + i) * sp + y * width + x] += _dout * _g; - } else { - int i = z - width; - int j = i < y ? i : i + 1; - - T _g = g[(batch * chn + plane) * sp + j * width + x]; - dw[(batch * len + width + i) * sp + y * width + x] += _dout * _g; - } - } + int z = blockIdx.z % len; + int batch = blockIdx.z / len; + + if (x < width && y < height) { + int widx = (batch * len + z) * sp + y*width + x; + int dout_idx = batch * chn * sp + y * width + x; + int gidx = batch * chn * sp; + if (z < width) { + gidx += y * width + z; + } else { + int j = z - width; + j = j < y ? j : j + 1; + gidx += j * width + x; + } + for(int plane = 0; plane < chn; plane ++){ + dw[widx] += dout[dout_idx + plane * sp] * g[gidx+plane*sp]; } } } @@ -161,25 +154,20 @@ __global__ void ca_map_backward_kernel_g(const T *dout, const T *weight, int y = blockIdx.y * blockDim.y + threadIdx.y; int sp = height * width; int len = height + width - 1; - int plane = blockIdx.z; - - if (x < width && y < height && plane < chn) { - for (int batch = 0; batch < num; ++batch) { - for (int i = 0; i < width; ++i) { - T _dout = dout[(batch * chn + plane) * sp + y * width + i]; - T _w = weight[(batch * len + x) * sp + y * width + i]; - dg[(batch * chn + plane) * sp + y * width + x] += _dout * _w; - } - for (int i = 0; i < height; ++i) { - if (i == y) continue; - int j = i > y ? y : y - 1; + int plane = blockIdx.z % chn; + int batch = blockIdx.z / chn; + int index = (batch * chn + plane) * sp + y*width + x; - T _dout = dout[(batch * chn + plane) * sp + i * width + x]; - T _w = weight[(batch * len + width + j) * sp + i * width + x]; - dg[(batch * chn + plane) * sp + y * width + x] += _dout * _w; - } + if (x < width && y < height) { + for (int i = 0; i < width; ++i) { + dg[index] += dout[(batch * chn + plane) * sp + y*width + i] * weight[(batch * len + x) * sp + y*width + i]; + } + int j = 0; + for (int i = 0; i < height; ++i) { + if (i == y) continue; + j = i > y ? y : y - 1; + dg[index] += dout[(batch * chn + plane) * sp + i * width + x] * weight[(batch * len + width + j) * sp + i * width + x]; } } } - #endif // CC_ATTENTION_CUDA_KERNEL_CUH diff --git a/mmcv/ops/csrc/pytorch/cc_attention_cuda.cu b/mmcv/ops/csrc/pytorch/cc_attention_cuda.cu index b948d5406a..fd4e7fd128 100644 --- a/mmcv/ops/csrc/pytorch/cc_attention_cuda.cu +++ b/mmcv/ops/csrc/pytorch/cc_attention_cuda.cu @@ -24,8 +24,8 @@ void CAForwardCUDAKernelLauncher(const Tensor t, const Tensor f, dim3 threads(32, 32); int d1 = (w + threads.x - 1) / threads.x; int d2 = (h + threads.y - 1) / threads.y; - int d3 = h + w; - dim3 blocks(d1, d2, d3); + int d3 = h + w - 1; + dim3 blocks(d1, d2, d3 * n); AT_DISPATCH_FLOATING_TYPES(t.scalar_type(), "ca_forward", [&] { ca_forward_kernel<<>>( @@ -53,7 +53,7 @@ void CABackwardCUDAKernelLauncher(const Tensor dw, const Tensor t, dim3 threads(32, 32); int d1 = (w + threads.x - 1) / threads.x; int d2 = (h + threads.y - 1) / threads.y; - int d3 = c; + int d3 = c * n; dim3 blocks(d1, d2, d3); AT_DISPATCH_FLOATING_TYPES(t.scalar_type(), "ca_backward_kernel_t", [&] { @@ -90,7 +90,7 @@ void CAMapForwardCUDAKernelLauncher(const Tensor weight, const Tensor g, dim3 threads(32, 32); int d1 = (w + threads.x - 1) / threads.x; int d2 = (h + threads.y - 1) / threads.y; - int d3 = c; + int d3 = c * n; dim3 blocks(d1, d2, d3); AT_DISPATCH_FLOATING_TYPES(g.scalar_type(), "ca_map_forward", [&] { @@ -119,8 +119,8 @@ void CAMapBackwardCUDAKernelLauncher(const Tensor dout, const Tensor weight, dim3 threads(32, 32); int d1 = (w + threads.x - 1) / threads.x; int d2 = (h + threads.y - 1) / threads.y; - int d3 = h + w; - dim3 blocks(d1, d2, d3); + int d3 = h + w - 1; + dim3 blocks(d1, d2, d3 * n); AT_DISPATCH_FLOATING_TYPES( weight.scalar_type(), "ca_map_backward_kernel_w", [&] { @@ -130,7 +130,8 @@ void CAMapBackwardCUDAKernelLauncher(const Tensor dout, const Tensor weight, g.contiguous().data_ptr(), dw.contiguous().data_ptr(), n, c, h, w); }); - + d3 = c * n; + blocks = dim3(d1, d2, d3); AT_DISPATCH_FLOATING_TYPES(g.scalar_type(), "ca_map_backward_kernel_g", [&] { ca_map_backward_kernel_g<<>>( dout.contiguous().data_ptr(), diff --git a/mmcv/ops/csrc/pytorch/pybind.cpp b/mmcv/ops/csrc/pytorch/pybind.cpp index 0b88e55658..9a41a5ddf8 100644 --- a/mmcv/ops/csrc/pytorch/pybind.cpp +++ b/mmcv/ops/csrc/pytorch/pybind.cpp @@ -230,236 +230,237 @@ void border_align_backward(const Tensor &grad_output, const Tensor &boxes, const Tensor &argmax_idx, Tensor grad_input, const int pool_size); -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)", py::arg("input"), - py::arg("kernel"), py::arg("up_x"), py::arg("up_y"), py::arg("down_x"), - py::arg("down_y"), py::arg("pad_x0"), py::arg("pad_x1"), - py::arg("pad_y0"), py::arg("pad_y1")); - m.def("fused_bias_leakyrelu", &fused_bias_leakyrelu, - "fused_bias_leakyrelu (CUDA)", py::arg("input"), py::arg("bias"), - py::arg("empty"), py::arg("act"), py::arg("grad"), py::arg("alpha"), - py::arg("scale")); - m.def("get_compiler_version", &get_compiler_version, "get_compiler_version"); - m.def("get_compiling_cuda_version", &get_compiling_cuda_version, - "get_compiling_cuda_version"); - m.def("carafe_naive_forward", &carafe_naive_forward, "carafe_naive_forward", - py::arg("features"), py::arg("masks"), py::arg("output"), - py::arg("kernel_size"), py::arg("group_size"), py::arg("scale_factor")); - m.def("carafe_naive_backward", &carafe_naive_backward, - "carafe_naive_backward", py::arg("top_grad"), py::arg("features"), - py::arg("masks"), py::arg("bottom_grad"), py::arg("mask_grad"), - py::arg("kernel_size"), py::arg("group_size"), py::arg("scale_factor")); - m.def("carafe_forward", &carafe_forward, "carafe_forward", - py::arg("features"), py::arg("masks"), py::arg("rfeatures"), - py::arg("routput"), py::arg("rmasks"), py::arg("output"), - py::arg("kernel_size"), py::arg("group_size"), py::arg("scale_factor")); - m.def("carafe_backward", &carafe_backward, "carafe_backward", - py::arg("top_grad"), py::arg("rfeatures"), py::arg("masks"), - py::arg("rtop_grad"), py::arg("rbottom_grad_hs"), - py::arg("rbottom_grad"), py::arg("rmask_grad"), py::arg("bottom_grad"), - py::arg("mask_grad"), py::arg("kernel_size"), py::arg("group_size"), - py::arg("scale_factor")); - m.def("deform_conv_forward", &deform_conv_forward, "deform_conv_forward", - py::arg("input"), py::arg("weight"), py::arg("offset"), - py::arg("output"), py::arg("columns"), py::arg("ones"), py::arg("kW"), - py::arg("kH"), py::arg("dW"), py::arg("dH"), py::arg("padH"), - py::arg("padW"), py::arg("dilationW"), py::arg("dilationH"), - py::arg("group"), py::arg("deformable_group"), py::arg("im2col_step")); - m.def("deform_conv_backward_input", &deform_conv_backward_input, - "deform_conv_backward_input", py::arg("input"), py::arg("offset"), - py::arg("gradOutput"), py::arg("gradInput"), py::arg("gradOffset"), - py::arg("weight"), py::arg("columns"), py::arg("kW"), py::arg("kH"), - py::arg("dW"), py::arg("dH"), py::arg("padH"), py::arg("padW"), - py::arg("dilationW"), py::arg("dilationH"), py::arg("group"), - py::arg("deformable_group"), py::arg("im2col_step")); - m.def("deform_conv_backward_parameters", &deform_conv_backward_parameters, - "deform_conv_backward_parameters", py::arg("input"), py::arg("offset"), - py::arg("gradOutput"), py::arg("gradWeight"), py::arg("columns"), - py::arg("ones"), py::arg("kW"), py::arg("kH"), py::arg("dW"), - py::arg("dH"), py::arg("padH"), py::arg("padW"), py::arg("dilationW"), - py::arg("dilationH"), py::arg("group"), py::arg("deformable_group"), - py::arg("scale"), py::arg("im2col_step")); - m.def("deform_roi_pool_forward", &deform_roi_pool_forward, - "deform roi pool forward", py::arg("input"), py::arg("rois"), - py::arg("offset"), py::arg("output"), py::arg("pooled_height"), - py::arg("pooled_width"), py::arg("spatial_scale"), - py::arg("sampling_ratio"), py::arg("gamma")); - m.def("deform_roi_pool_backward", &deform_roi_pool_backward, - "deform roi pool backward", py::arg("grad_output"), py::arg("input"), - py::arg("rois"), py::arg("offset"), py::arg("grad_input"), - py::arg("grad_offset"), py::arg("pooled_height"), - py::arg("pooled_width"), py::arg("spatial_scale"), - py::arg("sampling_ratio"), py::arg("gamma")); - m.def("sigmoid_focal_loss_forward", &sigmoid_focal_loss_forward, - "sigmoid_focal_loss_forward ", py::arg("input"), py::arg("target"), - py::arg("weight"), py::arg("output"), py::arg("gamma"), - py::arg("alpha")); - m.def("sigmoid_focal_loss_backward", &sigmoid_focal_loss_backward, - "sigmoid_focal_loss_backward", py::arg("input"), py::arg("target"), - py::arg("weight"), py::arg("grad_input"), py::arg("gamma"), - py::arg("alpha")); - m.def("softmax_focal_loss_forward", &softmax_focal_loss_forward, - "softmax_focal_loss_forward", py::arg("input"), py::arg("target"), - py::arg("weight"), py::arg("output"), py::arg("gamma"), - py::arg("alpha")); - m.def("softmax_focal_loss_backward", &softmax_focal_loss_backward, - "softmax_focal_loss_backward", py::arg("input"), py::arg("target"), - py::arg("weight"), py::arg("buff"), py::arg("grad_input"), - py::arg("gamma"), py::arg("alpha")); - m.def("bbox_overlaps", &bbox_overlaps, "bbox_overlaps", py::arg("bboxes1"), - py::arg("bboxes2"), py::arg("ious"), py::arg("mode"), - py::arg("aligned"), py::arg("offset")); - m.def("masked_im2col_forward", &masked_im2col_forward, - "masked_im2col_forward", py::arg("im"), py::arg("mask_h_idx"), - py::arg("mask_w_idx"), py::arg("col"), py::arg("kernel_h"), - py::arg("kernel_w"), py::arg("pad_h"), py::arg("pad_w")); - m.def("masked_col2im_forward", &masked_col2im_forward, - "masked_col2im_forward", py::arg("col"), py::arg("mask_h_idx"), - py::arg("mask_w_idx"), py::arg("im"), py::arg("height"), - py::arg("width"), py::arg("channels")); - m.def("modulated_deform_conv_forward", &modulated_deform_conv_forward, - "modulated deform conv forward", py::arg("input"), py::arg("weight"), - py::arg("bias"), py::arg("ones"), py::arg("offset"), py::arg("mask"), - py::arg("output"), py::arg("columns"), py::arg("kernel_h"), - py::arg("kernel_w"), py::arg("stride_h"), py::arg("stride_w"), - py::arg("pad_h"), py::arg("pad_w"), py::arg("dilation_h"), - py::arg("dilation_w"), py::arg("group"), py::arg("deformable_group"), - py::arg("with_bias")); - m.def("modulated_deform_conv_backward", &modulated_deform_conv_backward, - "modulated deform conv backward", py::arg("input"), py::arg("weight"), - py::arg("bias"), py::arg("ones"), py::arg("offset"), py::arg("mask"), - py::arg("columns"), py::arg("grad_input"), py::arg("grad_weight"), - py::arg("grad_bias"), py::arg("grad_offset"), py::arg("grad_mask"), - py::arg("grad_output"), py::arg("kernel_h"), py::arg("kernel_w"), - py::arg("stride_h"), py::arg("stride_w"), py::arg("pad_h"), - py::arg("pad_w"), py::arg("dilation_h"), py::arg("dilation_w"), - py::arg("group"), py::arg("deformable_group"), py::arg("with_bias")); - m.def("nms", &nms, "nms (CPU/CUDA) ", py::arg("boxes"), py::arg("scores"), - py::arg("iou_threshold"), py::arg("offset")); - m.def("softnms", &softnms, "softnms (CPU) ", py::arg("boxes"), - py::arg("scores"), py::arg("dets"), py::arg("iou_threshold"), - py::arg("sigma"), py::arg("min_score"), py::arg("method"), - py::arg("offset")); - m.def("nms_match", &nms_match, "nms_match (CPU) ", py::arg("dets"), - py::arg("iou_threshold")); - m.def("pixel_group", &pixel_group, "pixel group (CPU) ", py::arg("score"), - py::arg("mask"), py::arg("embedding"), py::arg("kernel_label"), - py::arg("kernel_contour"), py::arg("kernel_region_label"), - py::arg("distance_threshold")); - m.def("contour_expand", &contour_expand, "contour exapnd (CPU) ", - py::arg("kernel_mask"), py::arg("internal_kernel_label"), - py::arg("min_kernel_area"), py::arg("kernel_num")); - m.def("roi_align_forward", &roi_align_forward, "roi_align forward", - py::arg("input"), py::arg("rois"), py::arg("output"), - py::arg("argmax_y"), py::arg("argmax_x"), py::arg("aligned_height"), - py::arg("aligned_width"), py::arg("spatial_scale"), - py::arg("sampling_ratio"), py::arg("pool_mode"), py::arg("aligned")); - m.def("roi_align_backward", &roi_align_backward, "roi_align backward", - py::arg("grad_output"), py::arg("rois"), py::arg("argmax_y"), - py::arg("argmax_x"), py::arg("grad_input"), py::arg("aligned_height"), - py::arg("aligned_width"), py::arg("spatial_scale"), - py::arg("sampling_ratio"), py::arg("pool_mode"), py::arg("aligned")); - m.def("roi_pool_forward", &roi_pool_forward, "roi_pool forward", - py::arg("input"), py::arg("rois"), py::arg("output"), py::arg("argmax"), - py::arg("pooled_height"), py::arg("pooled_width"), - py::arg("spatial_scale")); - m.def("roi_pool_backward", &roi_pool_backward, "roi_pool backward", - py::arg("grad_output"), py::arg("rois"), py::arg("argmax"), - py::arg("grad_input"), py::arg("pooled_height"), - py::arg("pooled_width"), py::arg("spatial_scale")); - m.def("sync_bn_forward_mean", &sync_bn_forward_mean, "sync_bn forward_mean", - py::arg("input"), py::arg("mean")); - m.def("sync_bn_forward_var", &sync_bn_forward_var, "sync_bn forward_var", - py::arg("input"), py::arg("mean"), py::arg("var")); - m.def("sync_bn_forward_output", &sync_bn_forward_output, - "sync_bn forward_output", py::arg("input"), py::arg("mean"), - py::arg("var"), py::arg("weight"), py::arg("bias"), - py::arg("running_mean"), py::arg("running_var"), py::arg("norm"), - py::arg("std"), py::arg("output"), py::arg("eps"), py::arg("momentum"), - py::arg("group_size")); - m.def("sync_bn_backward_param", &sync_bn_backward_param, - "sync_bn backward_param", py::arg("grad_output"), py::arg("norm"), - py::arg("grad_weight"), py::arg("grad_bias")); - m.def("sync_bn_backward_data", &sync_bn_backward_data, - "sync_bn backward_data", py::arg("grad_output"), py::arg("weight"), - py::arg("grad_weight"), py::arg("grad_bias"), py::arg("norm"), - py::arg("std"), py::arg("grad_input")); - m.def("ca_forward", &ca_forward, "ccattention forward", py::arg("t"), - py::arg("f"), py::arg("weight")); - m.def("ca_backward", &ca_backward, "ccattention backward", py::arg("dw"), - py::arg("t"), py::arg("f"), py::arg("dt"), py::arg("df")); - m.def("ca_map_forward", &ca_map_forward, "ccattention map forward", - py::arg("weight"), py::arg("g"), py::arg("out")); - m.def("ca_map_backward", &ca_map_backward, "ccattention map backward", - py::arg("dout"), py::arg("weight"), py::arg("g"), py::arg("dw"), - py::arg("dg")); - m.def("psamask_forward", &psamask_forward, "PSAMASK forward (CPU/CUDA)", - py::arg("input"), py::arg("output"), py::arg("psa_type"), - py::arg("num_"), py::arg("h_feature"), py::arg("w_feature"), - py::arg("h_mask"), py::arg("w_mask"), py::arg("half_h_mask"), - py::arg("half_w_mask")); - m.def("psamask_backward", &psamask_backward, "PSAMASK backward (CPU/CUDA)", - py::arg("grad_output"), py::arg("grad_input"), py::arg("psa_type"), - py::arg("num_"), py::arg("h_feature"), py::arg("w_feature"), - py::arg("h_mask"), py::arg("w_mask"), py::arg("half_h_mask"), - py::arg("half_w_mask")); - m.def("tin_shift_forward", &tin_shift_forward, "tin_shift forward", - py::arg("input"), py::arg("shift"), py::arg("output")); - m.def("tin_shift_backward", &tin_shift_backward, "tin_shift backward", - py::arg("grad_output"), py::arg("shift"), py::arg("grad_input")); - m.def("bottom_pool_forward", &bottom_pool_forward, "Bottom Pool Forward", - py::arg("input"), py::call_guard()); - m.def("bottom_pool_backward", &bottom_pool_backward, "Bottom Pool Backward", - py::arg("input"), py::arg("grad_output"), - py::call_guard()); - m.def("left_pool_forward", &left_pool_forward, "Left Pool Forward", - py::arg("input"), py::call_guard()); - m.def("left_pool_backward", &left_pool_backward, "Left Pool Backward", - py::arg("input"), py::arg("grad_output"), - py::call_guard()); - m.def("right_pool_forward", &right_pool_forward, "Right Pool Forward", - py::arg("input"), py::call_guard()); - m.def("right_pool_backward", &right_pool_backward, "Right Pool Backward", - py::arg("input"), py::arg("grad_output"), - py::call_guard()); - m.def("top_pool_forward", &top_pool_forward, "Top Pool Forward", - py::arg("input"), py::call_guard()); - m.def("top_pool_backward", &top_pool_backward, "Top Pool Backward", - py::arg("input"), py::arg("grad_output"), - py::call_guard()); - m.def("box_iou_rotated", &box_iou_rotated, "IoU for rotated boxes", - py::arg("boxes1"), py::arg("boxes2"), py::arg("ious"), - py::arg("mode_flag"), py::arg("aligned")); - m.def("nms_rotated", &nms_rotated, "NMS for rotated boxes", py::arg("dets"), - py::arg("scores"), py::arg("order"), py::arg("dets_sorted"), - py::arg("iou_threshold"), py::arg("multi_label")); - m.def("roi_align_rotated_forward", &roi_align_rotated_forward, - "roi_align_rotated forward", py::arg("input"), py::arg("rois"), - py::arg("output"), py::arg("pooled_height"), py::arg("pooled_width"), - py::arg("spatial_scale"), py::arg("sample_num"), py::arg("aligned"), - py::arg("clockwise")); - m.def("roi_align_rotated_backward", &roi_align_rotated_backward, - "roi_align_rotated backward", py::arg("grad_output"), py::arg("rois"), - py::arg("grad_input"), py::arg("pooled_height"), - py::arg("pooled_width"), py::arg("spatial_scale"), - py::arg("sample_num"), py::arg("aligned"), py::arg("clockwise")); - m.def("ms_deform_attn_forward", &ms_deform_attn_forward, - "forward function of multi-scale deformable attention", - py::arg("value"), py::arg("value_spatial_shapes"), - py::arg("value_level_start_index"), py::arg("sampling_locations"), - py::arg("attention_weights"), py::arg("im2col_step")); - m.def("ms_deform_attn_backward", &ms_deform_attn_backward, - "backward function of multi-scale deformable attention", - py::arg("value"), py::arg("value_spatial_shapes"), - py::arg("value_level_start_index"), py::arg("sampling_locations"), - py::arg("attention_weights"), py::arg("grad_output"), - py::arg("grad_value"), py::arg("grad_sampling_loc"), - py::arg("grad_attn_weight"), py::arg("im2col_step")); - m.def("border_align_forward", &border_align_forward, - "forward function of border_align", py::arg("input"), py::arg("boxes"), - py::arg("output"), py::arg("argmax_idx"), py::arg("pool_size")); - m.def("border_align_backward", &border_align_backward, - "backward function of border_align", py::arg("grad_output"), - py::arg("boxes"), py::arg("argmax_idx"), py::arg("grad_input"), - py::arg("pool_size")); +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) +{ + m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)", py::arg("input"), + py::arg("kernel"), py::arg("up_x"), py::arg("up_y"), py::arg("down_x"), + py::arg("down_y"), py::arg("pad_x0"), py::arg("pad_x1"), + py::arg("pad_y0"), py::arg("pad_y1")); + m.def("fused_bias_leakyrelu", &fused_bias_leakyrelu, + "fused_bias_leakyrelu (CUDA)", py::arg("input"), py::arg("bias"), + py::arg("empty"), py::arg("act"), py::arg("grad"), py::arg("alpha"), + py::arg("scale")); + m.def("get_compiler_version", &get_compiler_version, "get_compiler_version"); + m.def("get_compiling_cuda_version", &get_compiling_cuda_version, + "get_compiling_cuda_version"); + m.def("carafe_naive_forward", &carafe_naive_forward, "carafe_naive_forward", + py::arg("features"), py::arg("masks"), py::arg("output"), + py::arg("kernel_size"), py::arg("group_size"), py::arg("scale_factor")); + m.def("carafe_naive_backward", &carafe_naive_backward, + "carafe_naive_backward", py::arg("top_grad"), py::arg("features"), + py::arg("masks"), py::arg("bottom_grad"), py::arg("mask_grad"), + py::arg("kernel_size"), py::arg("group_size"), py::arg("scale_factor")); + m.def("carafe_forward", &carafe_forward, "carafe_forward", + py::arg("features"), py::arg("masks"), py::arg("rfeatures"), + py::arg("routput"), py::arg("rmasks"), py::arg("output"), + py::arg("kernel_size"), py::arg("group_size"), py::arg("scale_factor")); + m.def("carafe_backward", &carafe_backward, "carafe_backward", + py::arg("top_grad"), py::arg("rfeatures"), py::arg("masks"), + py::arg("rtop_grad"), py::arg("rbottom_grad_hs"), + py::arg("rbottom_grad"), py::arg("rmask_grad"), py::arg("bottom_grad"), + py::arg("mask_grad"), py::arg("kernel_size"), py::arg("group_size"), + py::arg("scale_factor")); + m.def("deform_conv_forward", &deform_conv_forward, "deform_conv_forward", + py::arg("input"), py::arg("weight"), py::arg("offset"), + py::arg("output"), py::arg("columns"), py::arg("ones"), py::arg("kW"), + py::arg("kH"), py::arg("dW"), py::arg("dH"), py::arg("padH"), + py::arg("padW"), py::arg("dilationW"), py::arg("dilationH"), + py::arg("group"), py::arg("deformable_group"), py::arg("im2col_step")); + m.def("deform_conv_backward_input", &deform_conv_backward_input, + "deform_conv_backward_input", py::arg("input"), py::arg("offset"), + py::arg("gradOutput"), py::arg("gradInput"), py::arg("gradOffset"), + py::arg("weight"), py::arg("columns"), py::arg("kW"), py::arg("kH"), + py::arg("dW"), py::arg("dH"), py::arg("padH"), py::arg("padW"), + py::arg("dilationW"), py::arg("dilationH"), py::arg("group"), + py::arg("deformable_group"), py::arg("im2col_step")); + m.def("deform_conv_backward_parameters", &deform_conv_backward_parameters, + "deform_conv_backward_parameters", py::arg("input"), py::arg("offset"), + py::arg("gradOutput"), py::arg("gradWeight"), py::arg("columns"), + py::arg("ones"), py::arg("kW"), py::arg("kH"), py::arg("dW"), + py::arg("dH"), py::arg("padH"), py::arg("padW"), py::arg("dilationW"), + py::arg("dilationH"), py::arg("group"), py::arg("deformable_group"), + py::arg("scale"), py::arg("im2col_step")); + m.def("deform_roi_pool_forward", &deform_roi_pool_forward, + "deform roi pool forward", py::arg("input"), py::arg("rois"), + py::arg("offset"), py::arg("output"), py::arg("pooled_height"), + py::arg("pooled_width"), py::arg("spatial_scale"), + py::arg("sampling_ratio"), py::arg("gamma")); + m.def("deform_roi_pool_backward", &deform_roi_pool_backward, + "deform roi pool backward", py::arg("grad_output"), py::arg("input"), + py::arg("rois"), py::arg("offset"), py::arg("grad_input"), + py::arg("grad_offset"), py::arg("pooled_height"), + py::arg("pooled_width"), py::arg("spatial_scale"), + py::arg("sampling_ratio"), py::arg("gamma")); + m.def("sigmoid_focal_loss_forward", &sigmoid_focal_loss_forward, + "sigmoid_focal_loss_forward ", py::arg("input"), py::arg("target"), + py::arg("weight"), py::arg("output"), py::arg("gamma"), + py::arg("alpha")); + m.def("sigmoid_focal_loss_backward", &sigmoid_focal_loss_backward, + "sigmoid_focal_loss_backward", py::arg("input"), py::arg("target"), + py::arg("weight"), py::arg("grad_input"), py::arg("gamma"), + py::arg("alpha")); + m.def("softmax_focal_loss_forward", &softmax_focal_loss_forward, + "softmax_focal_loss_forward", py::arg("input"), py::arg("target"), + py::arg("weight"), py::arg("output"), py::arg("gamma"), + py::arg("alpha")); + m.def("softmax_focal_loss_backward", &softmax_focal_loss_backward, + "softmax_focal_loss_backward", py::arg("input"), py::arg("target"), + py::arg("weight"), py::arg("buff"), py::arg("grad_input"), + py::arg("gamma"), py::arg("alpha")); + m.def("bbox_overlaps", &bbox_overlaps, "bbox_overlaps", py::arg("bboxes1"), + py::arg("bboxes2"), py::arg("ious"), py::arg("mode"), + py::arg("aligned"), py::arg("offset")); + m.def("masked_im2col_forward", &masked_im2col_forward, + "masked_im2col_forward", py::arg("im"), py::arg("mask_h_idx"), + py::arg("mask_w_idx"), py::arg("col"), py::arg("kernel_h"), + py::arg("kernel_w"), py::arg("pad_h"), py::arg("pad_w")); + m.def("masked_col2im_forward", &masked_col2im_forward, + "masked_col2im_forward", py::arg("col"), py::arg("mask_h_idx"), + py::arg("mask_w_idx"), py::arg("im"), py::arg("height"), + py::arg("width"), py::arg("channels")); + m.def("modulated_deform_conv_forward", &modulated_deform_conv_forward, + "modulated deform conv forward", py::arg("input"), py::arg("weight"), + py::arg("bias"), py::arg("ones"), py::arg("offset"), py::arg("mask"), + py::arg("output"), py::arg("columns"), py::arg("kernel_h"), + py::arg("kernel_w"), py::arg("stride_h"), py::arg("stride_w"), + py::arg("pad_h"), py::arg("pad_w"), py::arg("dilation_h"), + py::arg("dilation_w"), py::arg("group"), py::arg("deformable_group"), + py::arg("with_bias")); + m.def("modulated_deform_conv_backward", &modulated_deform_conv_backward, + "modulated deform conv backward", py::arg("input"), py::arg("weight"), + py::arg("bias"), py::arg("ones"), py::arg("offset"), py::arg("mask"), + py::arg("columns"), py::arg("grad_input"), py::arg("grad_weight"), + py::arg("grad_bias"), py::arg("grad_offset"), py::arg("grad_mask"), + py::arg("grad_output"), py::arg("kernel_h"), py::arg("kernel_w"), + py::arg("stride_h"), py::arg("stride_w"), py::arg("pad_h"), + py::arg("pad_w"), py::arg("dilation_h"), py::arg("dilation_w"), + py::arg("group"), py::arg("deformable_group"), py::arg("with_bias")); + m.def("nms", &nms, "nms (CPU/CUDA) ", py::arg("boxes"), py::arg("scores"), + py::arg("iou_threshold"), py::arg("offset")); + m.def("softnms", &softnms, "softnms (CPU) ", py::arg("boxes"), + py::arg("scores"), py::arg("dets"), py::arg("iou_threshold"), + py::arg("sigma"), py::arg("min_score"), py::arg("method"), + py::arg("offset")); + m.def("nms_match", &nms_match, "nms_match (CPU) ", py::arg("dets"), + py::arg("iou_threshold")); + m.def("pixel_group", &pixel_group, "pixel group (CPU) ", py::arg("score"), + py::arg("mask"), py::arg("embedding"), py::arg("kernel_label"), + py::arg("kernel_contour"), py::arg("kernel_region_label"), + py::arg("distance_threshold")); + m.def("contour_expand", &contour_expand, "contour exapnd (CPU) ", + py::arg("kernel_mask"), py::arg("internal_kernel_label"), + py::arg("min_kernel_area"), py::arg("kernel_num")); + m.def("roi_align_forward", &roi_align_forward, "roi_align forward", + py::arg("input"), py::arg("rois"), py::arg("output"), + py::arg("argmax_y"), py::arg("argmax_x"), py::arg("aligned_height"), + py::arg("aligned_width"), py::arg("spatial_scale"), + py::arg("sampling_ratio"), py::arg("pool_mode"), py::arg("aligned")); + m.def("roi_align_backward", &roi_align_backward, "roi_align backward", + py::arg("grad_output"), py::arg("rois"), py::arg("argmax_y"), + py::arg("argmax_x"), py::arg("grad_input"), py::arg("aligned_height"), + py::arg("aligned_width"), py::arg("spatial_scale"), + py::arg("sampling_ratio"), py::arg("pool_mode"), py::arg("aligned")); + m.def("roi_pool_forward", &roi_pool_forward, "roi_pool forward", + py::arg("input"), py::arg("rois"), py::arg("output"), py::arg("argmax"), + py::arg("pooled_height"), py::arg("pooled_width"), + py::arg("spatial_scale")); + m.def("roi_pool_backward", &roi_pool_backward, "roi_pool backward", + py::arg("grad_output"), py::arg("rois"), py::arg("argmax"), + py::arg("grad_input"), py::arg("pooled_height"), + py::arg("pooled_width"), py::arg("spatial_scale")); + m.def("sync_bn_forward_mean", &sync_bn_forward_mean, "sync_bn forward_mean", + py::arg("input"), py::arg("mean")); + m.def("sync_bn_forward_var", &sync_bn_forward_var, "sync_bn forward_var", + py::arg("input"), py::arg("mean"), py::arg("var")); + m.def("sync_bn_forward_output", &sync_bn_forward_output, + "sync_bn forward_output", py::arg("input"), py::arg("mean"), + py::arg("var"), py::arg("weight"), py::arg("bias"), + py::arg("running_mean"), py::arg("running_var"), py::arg("norm"), + py::arg("std"), py::arg("output"), py::arg("eps"), py::arg("momentum"), + py::arg("group_size")); + m.def("sync_bn_backward_param", &sync_bn_backward_param, + "sync_bn backward_param", py::arg("grad_output"), py::arg("norm"), + py::arg("grad_weight"), py::arg("grad_bias")); + m.def("sync_bn_backward_data", &sync_bn_backward_data, + "sync_bn backward_data", py::arg("grad_output"), py::arg("weight"), + py::arg("grad_weight"), py::arg("grad_bias"), py::arg("norm"), + py::arg("std"), py::arg("grad_input")); + m.def("ca_forward", &ca_forward, "ccattention forward", py::arg("t"), + py::arg("f"), py::arg("weight")); + m.def("ca_backward", &ca_backward, "ccattention backward", py::arg("dw"), + py::arg("t"), py::arg("f"), py::arg("dt"), py::arg("df")); + m.def("ca_map_forward", &ca_map_forward, "ccattention map forward", + py::arg("weight"), py::arg("g"), py::arg("out")); + m.def("ca_map_backward", &ca_map_backward, "ccattention map backward", + py::arg("dout"), py::arg("weight"), py::arg("g"), py::arg("dw"), + py::arg("dg")); + m.def("psamask_forward", &psamask_forward, "PSAMASK forward (CPU/CUDA)", + py::arg("input"), py::arg("output"), py::arg("psa_type"), + py::arg("num_"), py::arg("h_feature"), py::arg("w_feature"), + py::arg("h_mask"), py::arg("w_mask"), py::arg("half_h_mask"), + py::arg("half_w_mask")); + m.def("psamask_backward", &psamask_backward, "PSAMASK backward (CPU/CUDA)", + py::arg("grad_output"), py::arg("grad_input"), py::arg("psa_type"), + py::arg("num_"), py::arg("h_feature"), py::arg("w_feature"), + py::arg("h_mask"), py::arg("w_mask"), py::arg("half_h_mask"), + py::arg("half_w_mask")); + m.def("tin_shift_forward", &tin_shift_forward, "tin_shift forward", + py::arg("input"), py::arg("shift"), py::arg("output")); + m.def("tin_shift_backward", &tin_shift_backward, "tin_shift backward", + py::arg("grad_output"), py::arg("shift"), py::arg("grad_input")); + m.def("bottom_pool_forward", &bottom_pool_forward, "Bottom Pool Forward", + py::arg("input"), py::call_guard()); + m.def("bottom_pool_backward", &bottom_pool_backward, "Bottom Pool Backward", + py::arg("input"), py::arg("grad_output"), + py::call_guard()); + m.def("left_pool_forward", &left_pool_forward, "Left Pool Forward", + py::arg("input"), py::call_guard()); + m.def("left_pool_backward", &left_pool_backward, "Left Pool Backward", + py::arg("input"), py::arg("grad_output"), + py::call_guard()); + m.def("right_pool_forward", &right_pool_forward, "Right Pool Forward", + py::arg("input"), py::call_guard()); + m.def("right_pool_backward", &right_pool_backward, "Right Pool Backward", + py::arg("input"), py::arg("grad_output"), + py::call_guard()); + m.def("top_pool_forward", &top_pool_forward, "Top Pool Forward", + py::arg("input"), py::call_guard()); + m.def("top_pool_backward", &top_pool_backward, "Top Pool Backward", + py::arg("input"), py::arg("grad_output"), + py::call_guard()); + m.def("box_iou_rotated", &box_iou_rotated, "IoU for rotated boxes", + py::arg("boxes1"), py::arg("boxes2"), py::arg("ious"), + py::arg("mode_flag"), py::arg("aligned")); + m.def("nms_rotated", &nms_rotated, "NMS for rotated boxes", py::arg("dets"), + py::arg("scores"), py::arg("order"), py::arg("dets_sorted"), + py::arg("iou_threshold"), py::arg("multi_label")); + m.def("roi_align_rotated_forward", &roi_align_rotated_forward, + "roi_align_rotated forward", py::arg("input"), py::arg("rois"), + py::arg("output"), py::arg("pooled_height"), py::arg("pooled_width"), + py::arg("spatial_scale"), py::arg("sample_num"), py::arg("aligned"), + py::arg("clockwise")); + m.def("roi_align_rotated_backward", &roi_align_rotated_backward, + "roi_align_rotated backward", py::arg("grad_output"), py::arg("rois"), + py::arg("grad_input"), py::arg("pooled_height"), + py::arg("pooled_width"), py::arg("spatial_scale"), + py::arg("sample_num"), py::arg("aligned"), py::arg("clockwise")); + m.def("ms_deform_attn_forward", &ms_deform_attn_forward, + "forward function of multi-scale deformable attention", + py::arg("value"), py::arg("value_spatial_shapes"), + py::arg("value_level_start_index"), py::arg("sampling_locations"), + py::arg("attention_weights"), py::arg("im2col_step")); + m.def("ms_deform_attn_backward", &ms_deform_attn_backward, + "backward function of multi-scale deformable attention", + py::arg("value"), py::arg("value_spatial_shapes"), + py::arg("value_level_start_index"), py::arg("sampling_locations"), + py::arg("attention_weights"), py::arg("grad_output"), + py::arg("grad_value"), py::arg("grad_sampling_loc"), + py::arg("grad_attn_weight"), py::arg("im2col_step")); + m.def("border_align_forward", &border_align_forward, + "forward function of border_align", py::arg("input"), py::arg("boxes"), + py::arg("output"), py::arg("argmax_idx"), py::arg("pool_size")); + m.def("border_align_backward", &border_align_backward, + "backward function of border_align", py::arg("grad_output"), + py::arg("boxes"), py::arg("argmax_idx"), py::arg("grad_input"), + py::arg("pool_size")); } From a20c2092647474f9a7042b37df7ecff92bc67d52 Mon Sep 17 00:00:00 2001 From: wondervictor Date: Tue, 8 Jun 2021 16:04:59 +0800 Subject: [PATCH 2/4] optimize criss cross attention --- .editorconfig | 59 ++++ mmcv/ops/csrc/pytorch/pybind.cpp | 467 +++++++++++++++---------------- 2 files changed, 292 insertions(+), 234 deletions(-) create mode 100644 .editorconfig diff --git a/.editorconfig b/.editorconfig new file mode 100644 index 0000000000..163988705d --- /dev/null +++ b/.editorconfig @@ -0,0 +1,59 @@ +[*] +cpp_indent_braces=false +cpp_indent_multi_line_relative_to=innermost_parenthesis +cpp_indent_within_parentheses=indent +cpp_indent_preserve_within_parentheses=false +cpp_indent_case_labels=false +cpp_indent_case_contents=true +cpp_indent_case_contents_when_block=false +cpp_indent_lambda_braces_when_parameter=true +cpp_indent_goto_labels=one_left +cpp_indent_preprocessor=leftmost_column +cpp_indent_access_specifiers=false +cpp_indent_namespace_contents=true +cpp_indent_preserve_comments=false +cpp_new_line_before_open_brace_namespace=ignore +cpp_new_line_before_open_brace_type=ignore +cpp_new_line_before_open_brace_function=ignore +cpp_new_line_before_open_brace_block=ignore +cpp_new_line_before_open_brace_lambda=ignore +cpp_new_line_scope_braces_on_separate_lines=false +cpp_new_line_close_brace_same_line_empty_type=false +cpp_new_line_close_brace_same_line_empty_function=false +cpp_new_line_before_catch=true +cpp_new_line_before_else=true +cpp_new_line_before_while_in_do_while=false +cpp_space_before_function_open_parenthesis=remove +cpp_space_within_parameter_list_parentheses=false +cpp_space_between_empty_parameter_list_parentheses=false +cpp_space_after_keywords_in_control_flow_statements=true +cpp_space_within_control_flow_statement_parentheses=false +cpp_space_before_lambda_open_parenthesis=false +cpp_space_within_cast_parentheses=false +cpp_space_after_cast_close_parenthesis=false +cpp_space_within_expression_parentheses=false +cpp_space_before_block_open_brace=true +cpp_space_between_empty_braces=false +cpp_space_before_initializer_list_open_brace=false +cpp_space_within_initializer_list_braces=true +cpp_space_preserve_in_initializer_list=true +cpp_space_before_open_square_bracket=false +cpp_space_within_square_brackets=false +cpp_space_before_empty_square_brackets=false +cpp_space_between_empty_square_brackets=false +cpp_space_group_square_brackets=true +cpp_space_within_lambda_brackets=false +cpp_space_between_empty_lambda_brackets=false +cpp_space_before_comma=false +cpp_space_after_comma=true +cpp_space_remove_around_member_operators=true +cpp_space_before_inheritance_colon=true +cpp_space_before_constructor_colon=true +cpp_space_remove_before_semicolon=true +cpp_space_after_semicolon=false +cpp_space_remove_around_unary_operator=true +cpp_space_around_binary_operator=insert +cpp_space_around_assignment_operator=insert +cpp_space_pointer_reference_alignment=left +cpp_space_around_ternary_operator=insert +cpp_wrap_preserve_blocks=one_liners diff --git a/mmcv/ops/csrc/pytorch/pybind.cpp b/mmcv/ops/csrc/pytorch/pybind.cpp index 9a41a5ddf8..b36b787630 100644 --- a/mmcv/ops/csrc/pytorch/pybind.cpp +++ b/mmcv/ops/csrc/pytorch/pybind.cpp @@ -230,237 +230,236 @@ void border_align_backward(const Tensor &grad_output, const Tensor &boxes, const Tensor &argmax_idx, Tensor grad_input, const int pool_size); -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) -{ - m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)", py::arg("input"), - py::arg("kernel"), py::arg("up_x"), py::arg("up_y"), py::arg("down_x"), - py::arg("down_y"), py::arg("pad_x0"), py::arg("pad_x1"), - py::arg("pad_y0"), py::arg("pad_y1")); - m.def("fused_bias_leakyrelu", &fused_bias_leakyrelu, - "fused_bias_leakyrelu (CUDA)", py::arg("input"), py::arg("bias"), - py::arg("empty"), py::arg("act"), py::arg("grad"), py::arg("alpha"), - py::arg("scale")); - m.def("get_compiler_version", &get_compiler_version, "get_compiler_version"); - m.def("get_compiling_cuda_version", &get_compiling_cuda_version, - "get_compiling_cuda_version"); - m.def("carafe_naive_forward", &carafe_naive_forward, "carafe_naive_forward", - py::arg("features"), py::arg("masks"), py::arg("output"), - py::arg("kernel_size"), py::arg("group_size"), py::arg("scale_factor")); - m.def("carafe_naive_backward", &carafe_naive_backward, - "carafe_naive_backward", py::arg("top_grad"), py::arg("features"), - py::arg("masks"), py::arg("bottom_grad"), py::arg("mask_grad"), - py::arg("kernel_size"), py::arg("group_size"), py::arg("scale_factor")); - m.def("carafe_forward", &carafe_forward, "carafe_forward", - py::arg("features"), py::arg("masks"), py::arg("rfeatures"), - py::arg("routput"), py::arg("rmasks"), py::arg("output"), - py::arg("kernel_size"), py::arg("group_size"), py::arg("scale_factor")); - m.def("carafe_backward", &carafe_backward, "carafe_backward", - py::arg("top_grad"), py::arg("rfeatures"), py::arg("masks"), - py::arg("rtop_grad"), py::arg("rbottom_grad_hs"), - py::arg("rbottom_grad"), py::arg("rmask_grad"), py::arg("bottom_grad"), - py::arg("mask_grad"), py::arg("kernel_size"), py::arg("group_size"), - py::arg("scale_factor")); - m.def("deform_conv_forward", &deform_conv_forward, "deform_conv_forward", - py::arg("input"), py::arg("weight"), py::arg("offset"), - py::arg("output"), py::arg("columns"), py::arg("ones"), py::arg("kW"), - py::arg("kH"), py::arg("dW"), py::arg("dH"), py::arg("padH"), - py::arg("padW"), py::arg("dilationW"), py::arg("dilationH"), - py::arg("group"), py::arg("deformable_group"), py::arg("im2col_step")); - m.def("deform_conv_backward_input", &deform_conv_backward_input, - "deform_conv_backward_input", py::arg("input"), py::arg("offset"), - py::arg("gradOutput"), py::arg("gradInput"), py::arg("gradOffset"), - py::arg("weight"), py::arg("columns"), py::arg("kW"), py::arg("kH"), - py::arg("dW"), py::arg("dH"), py::arg("padH"), py::arg("padW"), - py::arg("dilationW"), py::arg("dilationH"), py::arg("group"), - py::arg("deformable_group"), py::arg("im2col_step")); - m.def("deform_conv_backward_parameters", &deform_conv_backward_parameters, - "deform_conv_backward_parameters", py::arg("input"), py::arg("offset"), - py::arg("gradOutput"), py::arg("gradWeight"), py::arg("columns"), - py::arg("ones"), py::arg("kW"), py::arg("kH"), py::arg("dW"), - py::arg("dH"), py::arg("padH"), py::arg("padW"), py::arg("dilationW"), - py::arg("dilationH"), py::arg("group"), py::arg("deformable_group"), - py::arg("scale"), py::arg("im2col_step")); - m.def("deform_roi_pool_forward", &deform_roi_pool_forward, - "deform roi pool forward", py::arg("input"), py::arg("rois"), - py::arg("offset"), py::arg("output"), py::arg("pooled_height"), - py::arg("pooled_width"), py::arg("spatial_scale"), - py::arg("sampling_ratio"), py::arg("gamma")); - m.def("deform_roi_pool_backward", &deform_roi_pool_backward, - "deform roi pool backward", py::arg("grad_output"), py::arg("input"), - py::arg("rois"), py::arg("offset"), py::arg("grad_input"), - py::arg("grad_offset"), py::arg("pooled_height"), - py::arg("pooled_width"), py::arg("spatial_scale"), - py::arg("sampling_ratio"), py::arg("gamma")); - m.def("sigmoid_focal_loss_forward", &sigmoid_focal_loss_forward, - "sigmoid_focal_loss_forward ", py::arg("input"), py::arg("target"), - py::arg("weight"), py::arg("output"), py::arg("gamma"), - py::arg("alpha")); - m.def("sigmoid_focal_loss_backward", &sigmoid_focal_loss_backward, - "sigmoid_focal_loss_backward", py::arg("input"), py::arg("target"), - py::arg("weight"), py::arg("grad_input"), py::arg("gamma"), - py::arg("alpha")); - m.def("softmax_focal_loss_forward", &softmax_focal_loss_forward, - "softmax_focal_loss_forward", py::arg("input"), py::arg("target"), - py::arg("weight"), py::arg("output"), py::arg("gamma"), - py::arg("alpha")); - m.def("softmax_focal_loss_backward", &softmax_focal_loss_backward, - "softmax_focal_loss_backward", py::arg("input"), py::arg("target"), - py::arg("weight"), py::arg("buff"), py::arg("grad_input"), - py::arg("gamma"), py::arg("alpha")); - m.def("bbox_overlaps", &bbox_overlaps, "bbox_overlaps", py::arg("bboxes1"), - py::arg("bboxes2"), py::arg("ious"), py::arg("mode"), - py::arg("aligned"), py::arg("offset")); - m.def("masked_im2col_forward", &masked_im2col_forward, - "masked_im2col_forward", py::arg("im"), py::arg("mask_h_idx"), - py::arg("mask_w_idx"), py::arg("col"), py::arg("kernel_h"), - py::arg("kernel_w"), py::arg("pad_h"), py::arg("pad_w")); - m.def("masked_col2im_forward", &masked_col2im_forward, - "masked_col2im_forward", py::arg("col"), py::arg("mask_h_idx"), - py::arg("mask_w_idx"), py::arg("im"), py::arg("height"), - py::arg("width"), py::arg("channels")); - m.def("modulated_deform_conv_forward", &modulated_deform_conv_forward, - "modulated deform conv forward", py::arg("input"), py::arg("weight"), - py::arg("bias"), py::arg("ones"), py::arg("offset"), py::arg("mask"), - py::arg("output"), py::arg("columns"), py::arg("kernel_h"), - py::arg("kernel_w"), py::arg("stride_h"), py::arg("stride_w"), - py::arg("pad_h"), py::arg("pad_w"), py::arg("dilation_h"), - py::arg("dilation_w"), py::arg("group"), py::arg("deformable_group"), - py::arg("with_bias")); - m.def("modulated_deform_conv_backward", &modulated_deform_conv_backward, - "modulated deform conv backward", py::arg("input"), py::arg("weight"), - py::arg("bias"), py::arg("ones"), py::arg("offset"), py::arg("mask"), - py::arg("columns"), py::arg("grad_input"), py::arg("grad_weight"), - py::arg("grad_bias"), py::arg("grad_offset"), py::arg("grad_mask"), - py::arg("grad_output"), py::arg("kernel_h"), py::arg("kernel_w"), - py::arg("stride_h"), py::arg("stride_w"), py::arg("pad_h"), - py::arg("pad_w"), py::arg("dilation_h"), py::arg("dilation_w"), - py::arg("group"), py::arg("deformable_group"), py::arg("with_bias")); - m.def("nms", &nms, "nms (CPU/CUDA) ", py::arg("boxes"), py::arg("scores"), - py::arg("iou_threshold"), py::arg("offset")); - m.def("softnms", &softnms, "softnms (CPU) ", py::arg("boxes"), - py::arg("scores"), py::arg("dets"), py::arg("iou_threshold"), - py::arg("sigma"), py::arg("min_score"), py::arg("method"), - py::arg("offset")); - m.def("nms_match", &nms_match, "nms_match (CPU) ", py::arg("dets"), - py::arg("iou_threshold")); - m.def("pixel_group", &pixel_group, "pixel group (CPU) ", py::arg("score"), - py::arg("mask"), py::arg("embedding"), py::arg("kernel_label"), - py::arg("kernel_contour"), py::arg("kernel_region_label"), - py::arg("distance_threshold")); - m.def("contour_expand", &contour_expand, "contour exapnd (CPU) ", - py::arg("kernel_mask"), py::arg("internal_kernel_label"), - py::arg("min_kernel_area"), py::arg("kernel_num")); - m.def("roi_align_forward", &roi_align_forward, "roi_align forward", - py::arg("input"), py::arg("rois"), py::arg("output"), - py::arg("argmax_y"), py::arg("argmax_x"), py::arg("aligned_height"), - py::arg("aligned_width"), py::arg("spatial_scale"), - py::arg("sampling_ratio"), py::arg("pool_mode"), py::arg("aligned")); - m.def("roi_align_backward", &roi_align_backward, "roi_align backward", - py::arg("grad_output"), py::arg("rois"), py::arg("argmax_y"), - py::arg("argmax_x"), py::arg("grad_input"), py::arg("aligned_height"), - py::arg("aligned_width"), py::arg("spatial_scale"), - py::arg("sampling_ratio"), py::arg("pool_mode"), py::arg("aligned")); - m.def("roi_pool_forward", &roi_pool_forward, "roi_pool forward", - py::arg("input"), py::arg("rois"), py::arg("output"), py::arg("argmax"), - py::arg("pooled_height"), py::arg("pooled_width"), - py::arg("spatial_scale")); - m.def("roi_pool_backward", &roi_pool_backward, "roi_pool backward", - py::arg("grad_output"), py::arg("rois"), py::arg("argmax"), - py::arg("grad_input"), py::arg("pooled_height"), - py::arg("pooled_width"), py::arg("spatial_scale")); - m.def("sync_bn_forward_mean", &sync_bn_forward_mean, "sync_bn forward_mean", - py::arg("input"), py::arg("mean")); - m.def("sync_bn_forward_var", &sync_bn_forward_var, "sync_bn forward_var", - py::arg("input"), py::arg("mean"), py::arg("var")); - m.def("sync_bn_forward_output", &sync_bn_forward_output, - "sync_bn forward_output", py::arg("input"), py::arg("mean"), - py::arg("var"), py::arg("weight"), py::arg("bias"), - py::arg("running_mean"), py::arg("running_var"), py::arg("norm"), - py::arg("std"), py::arg("output"), py::arg("eps"), py::arg("momentum"), - py::arg("group_size")); - m.def("sync_bn_backward_param", &sync_bn_backward_param, - "sync_bn backward_param", py::arg("grad_output"), py::arg("norm"), - py::arg("grad_weight"), py::arg("grad_bias")); - m.def("sync_bn_backward_data", &sync_bn_backward_data, - "sync_bn backward_data", py::arg("grad_output"), py::arg("weight"), - py::arg("grad_weight"), py::arg("grad_bias"), py::arg("norm"), - py::arg("std"), py::arg("grad_input")); - m.def("ca_forward", &ca_forward, "ccattention forward", py::arg("t"), - py::arg("f"), py::arg("weight")); - m.def("ca_backward", &ca_backward, "ccattention backward", py::arg("dw"), - py::arg("t"), py::arg("f"), py::arg("dt"), py::arg("df")); - m.def("ca_map_forward", &ca_map_forward, "ccattention map forward", - py::arg("weight"), py::arg("g"), py::arg("out")); - m.def("ca_map_backward", &ca_map_backward, "ccattention map backward", - py::arg("dout"), py::arg("weight"), py::arg("g"), py::arg("dw"), - py::arg("dg")); - m.def("psamask_forward", &psamask_forward, "PSAMASK forward (CPU/CUDA)", - py::arg("input"), py::arg("output"), py::arg("psa_type"), - py::arg("num_"), py::arg("h_feature"), py::arg("w_feature"), - py::arg("h_mask"), py::arg("w_mask"), py::arg("half_h_mask"), - py::arg("half_w_mask")); - m.def("psamask_backward", &psamask_backward, "PSAMASK backward (CPU/CUDA)", - py::arg("grad_output"), py::arg("grad_input"), py::arg("psa_type"), - py::arg("num_"), py::arg("h_feature"), py::arg("w_feature"), - py::arg("h_mask"), py::arg("w_mask"), py::arg("half_h_mask"), - py::arg("half_w_mask")); - m.def("tin_shift_forward", &tin_shift_forward, "tin_shift forward", - py::arg("input"), py::arg("shift"), py::arg("output")); - m.def("tin_shift_backward", &tin_shift_backward, "tin_shift backward", - py::arg("grad_output"), py::arg("shift"), py::arg("grad_input")); - m.def("bottom_pool_forward", &bottom_pool_forward, "Bottom Pool Forward", - py::arg("input"), py::call_guard()); - m.def("bottom_pool_backward", &bottom_pool_backward, "Bottom Pool Backward", - py::arg("input"), py::arg("grad_output"), - py::call_guard()); - m.def("left_pool_forward", &left_pool_forward, "Left Pool Forward", - py::arg("input"), py::call_guard()); - m.def("left_pool_backward", &left_pool_backward, "Left Pool Backward", - py::arg("input"), py::arg("grad_output"), - py::call_guard()); - m.def("right_pool_forward", &right_pool_forward, "Right Pool Forward", - py::arg("input"), py::call_guard()); - m.def("right_pool_backward", &right_pool_backward, "Right Pool Backward", - py::arg("input"), py::arg("grad_output"), - py::call_guard()); - m.def("top_pool_forward", &top_pool_forward, "Top Pool Forward", - py::arg("input"), py::call_guard()); - m.def("top_pool_backward", &top_pool_backward, "Top Pool Backward", - py::arg("input"), py::arg("grad_output"), - py::call_guard()); - m.def("box_iou_rotated", &box_iou_rotated, "IoU for rotated boxes", - py::arg("boxes1"), py::arg("boxes2"), py::arg("ious"), - py::arg("mode_flag"), py::arg("aligned")); - m.def("nms_rotated", &nms_rotated, "NMS for rotated boxes", py::arg("dets"), - py::arg("scores"), py::arg("order"), py::arg("dets_sorted"), - py::arg("iou_threshold"), py::arg("multi_label")); - m.def("roi_align_rotated_forward", &roi_align_rotated_forward, - "roi_align_rotated forward", py::arg("input"), py::arg("rois"), - py::arg("output"), py::arg("pooled_height"), py::arg("pooled_width"), - py::arg("spatial_scale"), py::arg("sample_num"), py::arg("aligned"), - py::arg("clockwise")); - m.def("roi_align_rotated_backward", &roi_align_rotated_backward, - "roi_align_rotated backward", py::arg("grad_output"), py::arg("rois"), - py::arg("grad_input"), py::arg("pooled_height"), - py::arg("pooled_width"), py::arg("spatial_scale"), - py::arg("sample_num"), py::arg("aligned"), py::arg("clockwise")); - m.def("ms_deform_attn_forward", &ms_deform_attn_forward, - "forward function of multi-scale deformable attention", - py::arg("value"), py::arg("value_spatial_shapes"), - py::arg("value_level_start_index"), py::arg("sampling_locations"), - py::arg("attention_weights"), py::arg("im2col_step")); - m.def("ms_deform_attn_backward", &ms_deform_attn_backward, - "backward function of multi-scale deformable attention", - py::arg("value"), py::arg("value_spatial_shapes"), - py::arg("value_level_start_index"), py::arg("sampling_locations"), - py::arg("attention_weights"), py::arg("grad_output"), - py::arg("grad_value"), py::arg("grad_sampling_loc"), - py::arg("grad_attn_weight"), py::arg("im2col_step")); - m.def("border_align_forward", &border_align_forward, - "forward function of border_align", py::arg("input"), py::arg("boxes"), - py::arg("output"), py::arg("argmax_idx"), py::arg("pool_size")); - m.def("border_align_backward", &border_align_backward, - "backward function of border_align", py::arg("grad_output"), - py::arg("boxes"), py::arg("argmax_idx"), py::arg("grad_input"), - py::arg("pool_size")); -} +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)", py::arg("input"), + py::arg("kernel"), py::arg("up_x"), py::arg("up_y"), py::arg("down_x"), + py::arg("down_y"), py::arg("pad_x0"), py::arg("pad_x1"), + py::arg("pad_y0"), py::arg("pad_y1")); + m.def("fused_bias_leakyrelu", &fused_bias_leakyrelu, + "fused_bias_leakyrelu (CUDA)", py::arg("input"), py::arg("bias"), + py::arg("empty"), py::arg("act"), py::arg("grad"), py::arg("alpha"), + py::arg("scale")); + m.def("get_compiler_version", &get_compiler_version, "get_compiler_version"); + m.def("get_compiling_cuda_version", &get_compiling_cuda_version, + "get_compiling_cuda_version"); + m.def("carafe_naive_forward", &carafe_naive_forward, "carafe_naive_forward", + py::arg("features"), py::arg("masks"), py::arg("output"), + py::arg("kernel_size"), py::arg("group_size"), py::arg("scale_factor")); + m.def("carafe_naive_backward", &carafe_naive_backward, + "carafe_naive_backward", py::arg("top_grad"), py::arg("features"), + py::arg("masks"), py::arg("bottom_grad"), py::arg("mask_grad"), + py::arg("kernel_size"), py::arg("group_size"), py::arg("scale_factor")); + m.def("carafe_forward", &carafe_forward, "carafe_forward", + py::arg("features"), py::arg("masks"), py::arg("rfeatures"), + py::arg("routput"), py::arg("rmasks"), py::arg("output"), + py::arg("kernel_size"), py::arg("group_size"), py::arg("scale_factor")); + m.def("carafe_backward", &carafe_backward, "carafe_backward", + py::arg("top_grad"), py::arg("rfeatures"), py::arg("masks"), + py::arg("rtop_grad"), py::arg("rbottom_grad_hs"), + py::arg("rbottom_grad"), py::arg("rmask_grad"), py::arg("bottom_grad"), + py::arg("mask_grad"), py::arg("kernel_size"), py::arg("group_size"), + py::arg("scale_factor")); + m.def("deform_conv_forward", &deform_conv_forward, "deform_conv_forward", + py::arg("input"), py::arg("weight"), py::arg("offset"), + py::arg("output"), py::arg("columns"), py::arg("ones"), py::arg("kW"), + py::arg("kH"), py::arg("dW"), py::arg("dH"), py::arg("padH"), + py::arg("padW"), py::arg("dilationW"), py::arg("dilationH"), + py::arg("group"), py::arg("deformable_group"), py::arg("im2col_step")); + m.def("deform_conv_backward_input", &deform_conv_backward_input, + "deform_conv_backward_input", py::arg("input"), py::arg("offset"), + py::arg("gradOutput"), py::arg("gradInput"), py::arg("gradOffset"), + py::arg("weight"), py::arg("columns"), py::arg("kW"), py::arg("kH"), + py::arg("dW"), py::arg("dH"), py::arg("padH"), py::arg("padW"), + py::arg("dilationW"), py::arg("dilationH"), py::arg("group"), + py::arg("deformable_group"), py::arg("im2col_step")); + m.def("deform_conv_backward_parameters", &deform_conv_backward_parameters, + "deform_conv_backward_parameters", py::arg("input"), py::arg("offset"), + py::arg("gradOutput"), py::arg("gradWeight"), py::arg("columns"), + py::arg("ones"), py::arg("kW"), py::arg("kH"), py::arg("dW"), + py::arg("dH"), py::arg("padH"), py::arg("padW"), py::arg("dilationW"), + py::arg("dilationH"), py::arg("group"), py::arg("deformable_group"), + py::arg("scale"), py::arg("im2col_step")); + m.def("deform_roi_pool_forward", &deform_roi_pool_forward, + "deform roi pool forward", py::arg("input"), py::arg("rois"), + py::arg("offset"), py::arg("output"), py::arg("pooled_height"), + py::arg("pooled_width"), py::arg("spatial_scale"), + py::arg("sampling_ratio"), py::arg("gamma")); + m.def("deform_roi_pool_backward", &deform_roi_pool_backward, + "deform roi pool backward", py::arg("grad_output"), py::arg("input"), + py::arg("rois"), py::arg("offset"), py::arg("grad_input"), + py::arg("grad_offset"), py::arg("pooled_height"), + py::arg("pooled_width"), py::arg("spatial_scale"), + py::arg("sampling_ratio"), py::arg("gamma")); + m.def("sigmoid_focal_loss_forward", &sigmoid_focal_loss_forward, + "sigmoid_focal_loss_forward ", py::arg("input"), py::arg("target"), + py::arg("weight"), py::arg("output"), py::arg("gamma"), + py::arg("alpha")); + m.def("sigmoid_focal_loss_backward", &sigmoid_focal_loss_backward, + "sigmoid_focal_loss_backward", py::arg("input"), py::arg("target"), + py::arg("weight"), py::arg("grad_input"), py::arg("gamma"), + py::arg("alpha")); + m.def("softmax_focal_loss_forward", &softmax_focal_loss_forward, + "softmax_focal_loss_forward", py::arg("input"), py::arg("target"), + py::arg("weight"), py::arg("output"), py::arg("gamma"), + py::arg("alpha")); + m.def("softmax_focal_loss_backward", &softmax_focal_loss_backward, + "softmax_focal_loss_backward", py::arg("input"), py::arg("target"), + py::arg("weight"), py::arg("buff"), py::arg("grad_input"), + py::arg("gamma"), py::arg("alpha")); + m.def("bbox_overlaps", &bbox_overlaps, "bbox_overlaps", py::arg("bboxes1"), + py::arg("bboxes2"), py::arg("ious"), py::arg("mode"), + py::arg("aligned"), py::arg("offset")); + m.def("masked_im2col_forward", &masked_im2col_forward, + "masked_im2col_forward", py::arg("im"), py::arg("mask_h_idx"), + py::arg("mask_w_idx"), py::arg("col"), py::arg("kernel_h"), + py::arg("kernel_w"), py::arg("pad_h"), py::arg("pad_w")); + m.def("masked_col2im_forward", &masked_col2im_forward, + "masked_col2im_forward", py::arg("col"), py::arg("mask_h_idx"), + py::arg("mask_w_idx"), py::arg("im"), py::arg("height"), + py::arg("width"), py::arg("channels")); + m.def("modulated_deform_conv_forward", &modulated_deform_conv_forward, + "modulated deform conv forward", py::arg("input"), py::arg("weight"), + py::arg("bias"), py::arg("ones"), py::arg("offset"), py::arg("mask"), + py::arg("output"), py::arg("columns"), py::arg("kernel_h"), + py::arg("kernel_w"), py::arg("stride_h"), py::arg("stride_w"), + py::arg("pad_h"), py::arg("pad_w"), py::arg("dilation_h"), + py::arg("dilation_w"), py::arg("group"), py::arg("deformable_group"), + py::arg("with_bias")); + m.def("modulated_deform_conv_backward", &modulated_deform_conv_backward, + "modulated deform conv backward", py::arg("input"), py::arg("weight"), + py::arg("bias"), py::arg("ones"), py::arg("offset"), py::arg("mask"), + py::arg("columns"), py::arg("grad_input"), py::arg("grad_weight"), + py::arg("grad_bias"), py::arg("grad_offset"), py::arg("grad_mask"), + py::arg("grad_output"), py::arg("kernel_h"), py::arg("kernel_w"), + py::arg("stride_h"), py::arg("stride_w"), py::arg("pad_h"), + py::arg("pad_w"), py::arg("dilation_h"), py::arg("dilation_w"), + py::arg("group"), py::arg("deformable_group"), py::arg("with_bias")); + m.def("nms", &nms, "nms (CPU/CUDA) ", py::arg("boxes"), py::arg("scores"), + py::arg("iou_threshold"), py::arg("offset")); + m.def("softnms", &softnms, "softnms (CPU) ", py::arg("boxes"), + py::arg("scores"), py::arg("dets"), py::arg("iou_threshold"), + py::arg("sigma"), py::arg("min_score"), py::arg("method"), + py::arg("offset")); + m.def("nms_match", &nms_match, "nms_match (CPU) ", py::arg("dets"), + py::arg("iou_threshold")); + m.def("pixel_group", &pixel_group, "pixel group (CPU) ", py::arg("score"), + py::arg("mask"), py::arg("embedding"), py::arg("kernel_label"), + py::arg("kernel_contour"), py::arg("kernel_region_label"), + py::arg("distance_threshold")); + m.def("contour_expand", &contour_expand, "contour exapnd (CPU) ", + py::arg("kernel_mask"), py::arg("internal_kernel_label"), + py::arg("min_kernel_area"), py::arg("kernel_num")); + m.def("roi_align_forward", &roi_align_forward, "roi_align forward", + py::arg("input"), py::arg("rois"), py::arg("output"), + py::arg("argmax_y"), py::arg("argmax_x"), py::arg("aligned_height"), + py::arg("aligned_width"), py::arg("spatial_scale"), + py::arg("sampling_ratio"), py::arg("pool_mode"), py::arg("aligned")); + m.def("roi_align_backward", &roi_align_backward, "roi_align backward", + py::arg("grad_output"), py::arg("rois"), py::arg("argmax_y"), + py::arg("argmax_x"), py::arg("grad_input"), py::arg("aligned_height"), + py::arg("aligned_width"), py::arg("spatial_scale"), + py::arg("sampling_ratio"), py::arg("pool_mode"), py::arg("aligned")); + m.def("roi_pool_forward", &roi_pool_forward, "roi_pool forward", + py::arg("input"), py::arg("rois"), py::arg("output"), py::arg("argmax"), + py::arg("pooled_height"), py::arg("pooled_width"), + py::arg("spatial_scale")); + m.def("roi_pool_backward", &roi_pool_backward, "roi_pool backward", + py::arg("grad_output"), py::arg("rois"), py::arg("argmax"), + py::arg("grad_input"), py::arg("pooled_height"), + py::arg("pooled_width"), py::arg("spatial_scale")); + m.def("sync_bn_forward_mean", &sync_bn_forward_mean, "sync_bn forward_mean", + py::arg("input"), py::arg("mean")); + m.def("sync_bn_forward_var", &sync_bn_forward_var, "sync_bn forward_var", + py::arg("input"), py::arg("mean"), py::arg("var")); + m.def("sync_bn_forward_output", &sync_bn_forward_output, + "sync_bn forward_output", py::arg("input"), py::arg("mean"), + py::arg("var"), py::arg("weight"), py::arg("bias"), + py::arg("running_mean"), py::arg("running_var"), py::arg("norm"), + py::arg("std"), py::arg("output"), py::arg("eps"), py::arg("momentum"), + py::arg("group_size")); + m.def("sync_bn_backward_param", &sync_bn_backward_param, + "sync_bn backward_param", py::arg("grad_output"), py::arg("norm"), + py::arg("grad_weight"), py::arg("grad_bias")); + m.def("sync_bn_backward_data", &sync_bn_backward_data, + "sync_bn backward_data", py::arg("grad_output"), py::arg("weight"), + py::arg("grad_weight"), py::arg("grad_bias"), py::arg("norm"), + py::arg("std"), py::arg("grad_input")); + m.def("ca_forward", &ca_forward, "ccattention forward", py::arg("t"), + py::arg("f"), py::arg("weight")); + m.def("ca_backward", &ca_backward, "ccattention backward", py::arg("dw"), + py::arg("t"), py::arg("f"), py::arg("dt"), py::arg("df")); + m.def("ca_map_forward", &ca_map_forward, "ccattention map forward", + py::arg("weight"), py::arg("g"), py::arg("out")); + m.def("ca_map_backward", &ca_map_backward, "ccattention map backward", + py::arg("dout"), py::arg("weight"), py::arg("g"), py::arg("dw"), + py::arg("dg")); + m.def("psamask_forward", &psamask_forward, "PSAMASK forward (CPU/CUDA)", + py::arg("input"), py::arg("output"), py::arg("psa_type"), + py::arg("num_"), py::arg("h_feature"), py::arg("w_feature"), + py::arg("h_mask"), py::arg("w_mask"), py::arg("half_h_mask"), + py::arg("half_w_mask")); + m.def("psamask_backward", &psamask_backward, "PSAMASK backward (CPU/CUDA)", + py::arg("grad_output"), py::arg("grad_input"), py::arg("psa_type"), + py::arg("num_"), py::arg("h_feature"), py::arg("w_feature"), + py::arg("h_mask"), py::arg("w_mask"), py::arg("half_h_mask"), + py::arg("half_w_mask")); + m.def("tin_shift_forward", &tin_shift_forward, "tin_shift forward", + py::arg("input"), py::arg("shift"), py::arg("output")); + m.def("tin_shift_backward", &tin_shift_backward, "tin_shift backward", + py::arg("grad_output"), py::arg("shift"), py::arg("grad_input")); + m.def("bottom_pool_forward", &bottom_pool_forward, "Bottom Pool Forward", + py::arg("input"), py::call_guard()); + m.def("bottom_pool_backward", &bottom_pool_backward, "Bottom Pool Backward", + py::arg("input"), py::arg("grad_output"), + py::call_guard()); + m.def("left_pool_forward", &left_pool_forward, "Left Pool Forward", + py::arg("input"), py::call_guard()); + m.def("left_pool_backward", &left_pool_backward, "Left Pool Backward", + py::arg("input"), py::arg("grad_output"), + py::call_guard()); + m.def("right_pool_forward", &right_pool_forward, "Right Pool Forward", + py::arg("input"), py::call_guard()); + m.def("right_pool_backward", &right_pool_backward, "Right Pool Backward", + py::arg("input"), py::arg("grad_output"), + py::call_guard()); + m.def("top_pool_forward", &top_pool_forward, "Top Pool Forward", + py::arg("input"), py::call_guard()); + m.def("top_pool_backward", &top_pool_backward, "Top Pool Backward", + py::arg("input"), py::arg("grad_output"), + py::call_guard()); + m.def("box_iou_rotated", &box_iou_rotated, "IoU for rotated boxes", + py::arg("boxes1"), py::arg("boxes2"), py::arg("ious"), + py::arg("mode_flag"), py::arg("aligned")); + m.def("nms_rotated", &nms_rotated, "NMS for rotated boxes", py::arg("dets"), + py::arg("scores"), py::arg("order"), py::arg("dets_sorted"), + py::arg("iou_threshold"), py::arg("multi_label")); + m.def("roi_align_rotated_forward", &roi_align_rotated_forward, + "roi_align_rotated forward", py::arg("input"), py::arg("rois"), + py::arg("output"), py::arg("pooled_height"), py::arg("pooled_width"), + py::arg("spatial_scale"), py::arg("sample_num"), py::arg("aligned"), + py::arg("clockwise")); + m.def("roi_align_rotated_backward", &roi_align_rotated_backward, + "roi_align_rotated backward", py::arg("grad_output"), py::arg("rois"), + py::arg("grad_input"), py::arg("pooled_height"), + py::arg("pooled_width"), py::arg("spatial_scale"), + py::arg("sample_num"), py::arg("aligned"), py::arg("clockwise")); + m.def("ms_deform_attn_forward", &ms_deform_attn_forward, + "forward function of multi-scale deformable attention", + py::arg("value"), py::arg("value_spatial_shapes"), + py::arg("value_level_start_index"), py::arg("sampling_locations"), + py::arg("attention_weights"), py::arg("im2col_step")); + m.def("ms_deform_attn_backward", &ms_deform_attn_backward, + "backward function of multi-scale deformable attention", + py::arg("value"), py::arg("value_spatial_shapes"), + py::arg("value_level_start_index"), py::arg("sampling_locations"), + py::arg("attention_weights"), py::arg("grad_output"), + py::arg("grad_value"), py::arg("grad_sampling_loc"), + py::arg("grad_attn_weight"), py::arg("im2col_step")); + m.def("border_align_forward", &border_align_forward, + "forward function of border_align", py::arg("input"), py::arg("boxes"), + py::arg("output"), py::arg("argmax_idx"), py::arg("pool_size")); + m.def("border_align_backward", &border_align_backward, + "backward function of border_align", py::arg("grad_output"), + py::arg("boxes"), py::arg("argmax_idx"), py::arg("grad_input"), + py::arg("pool_size")); +} \ No newline at end of file From 80b6fa7255ecfc46219e02b9b45d5de8c5fa3ea7 Mon Sep 17 00:00:00 2001 From: wondervictor Date: Tue, 8 Jun 2021 16:05:24 +0800 Subject: [PATCH 3/4] optimize criss cross attention --- .editorconfig | 59 --------------------------------------------------- 1 file changed, 59 deletions(-) delete mode 100644 .editorconfig diff --git a/.editorconfig b/.editorconfig deleted file mode 100644 index 163988705d..0000000000 --- a/.editorconfig +++ /dev/null @@ -1,59 +0,0 @@ -[*] -cpp_indent_braces=false -cpp_indent_multi_line_relative_to=innermost_parenthesis -cpp_indent_within_parentheses=indent -cpp_indent_preserve_within_parentheses=false -cpp_indent_case_labels=false -cpp_indent_case_contents=true -cpp_indent_case_contents_when_block=false -cpp_indent_lambda_braces_when_parameter=true -cpp_indent_goto_labels=one_left -cpp_indent_preprocessor=leftmost_column -cpp_indent_access_specifiers=false -cpp_indent_namespace_contents=true -cpp_indent_preserve_comments=false -cpp_new_line_before_open_brace_namespace=ignore -cpp_new_line_before_open_brace_type=ignore -cpp_new_line_before_open_brace_function=ignore -cpp_new_line_before_open_brace_block=ignore -cpp_new_line_before_open_brace_lambda=ignore -cpp_new_line_scope_braces_on_separate_lines=false -cpp_new_line_close_brace_same_line_empty_type=false -cpp_new_line_close_brace_same_line_empty_function=false -cpp_new_line_before_catch=true -cpp_new_line_before_else=true -cpp_new_line_before_while_in_do_while=false -cpp_space_before_function_open_parenthesis=remove -cpp_space_within_parameter_list_parentheses=false -cpp_space_between_empty_parameter_list_parentheses=false -cpp_space_after_keywords_in_control_flow_statements=true -cpp_space_within_control_flow_statement_parentheses=false -cpp_space_before_lambda_open_parenthesis=false -cpp_space_within_cast_parentheses=false -cpp_space_after_cast_close_parenthesis=false -cpp_space_within_expression_parentheses=false -cpp_space_before_block_open_brace=true -cpp_space_between_empty_braces=false -cpp_space_before_initializer_list_open_brace=false -cpp_space_within_initializer_list_braces=true -cpp_space_preserve_in_initializer_list=true -cpp_space_before_open_square_bracket=false -cpp_space_within_square_brackets=false -cpp_space_before_empty_square_brackets=false -cpp_space_between_empty_square_brackets=false -cpp_space_group_square_brackets=true -cpp_space_within_lambda_brackets=false -cpp_space_between_empty_lambda_brackets=false -cpp_space_before_comma=false -cpp_space_after_comma=true -cpp_space_remove_around_member_operators=true -cpp_space_before_inheritance_colon=true -cpp_space_before_constructor_colon=true -cpp_space_remove_before_semicolon=true -cpp_space_after_semicolon=false -cpp_space_remove_around_unary_operator=true -cpp_space_around_binary_operator=insert -cpp_space_around_assignment_operator=insert -cpp_space_pointer_reference_alignment=left -cpp_space_around_ternary_operator=insert -cpp_wrap_preserve_blocks=one_liners From 8a42569b7564de6be367e18eda3b9b2d76b4fc1e Mon Sep 17 00:00:00 2001 From: wondervictor Date: Wed, 9 Jun 2021 10:25:25 +0800 Subject: [PATCH 4/4] fix lint --- mmcv/ops/csrc/pytorch/pybind.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mmcv/ops/csrc/pytorch/pybind.cpp b/mmcv/ops/csrc/pytorch/pybind.cpp index b36b787630..0b88e55658 100644 --- a/mmcv/ops/csrc/pytorch/pybind.cpp +++ b/mmcv/ops/csrc/pytorch/pybind.cpp @@ -462,4 +462,4 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { "backward function of border_align", py::arg("grad_output"), py::arg("boxes"), py::arg("argmax_idx"), py::arg("grad_input"), py::arg("pool_size")); -} \ No newline at end of file +}