From 03056c2d2cc74fcf072f402c6d3a7fe66fa0a4e5 Mon Sep 17 00:00:00 2001 From: brianlcy123 Date: Tue, 16 Jan 2024 16:08:41 +0800 Subject: [PATCH] [Kunlunxin] Add Flip/MaxPool2dWithIndices/MaxPool2dGrad (#865) * [kunlunxin] fix ci - add missing file * [CI][kunlunxin]fix sub dtype * [CI][kunlunxin]fix reduce_mean&reduce_sum;add more test cases * [KUNLUNXIN] add flip/max_pool2d_with_indices/max_pool2d_grad --- impl/kunlunxin/convert_config.yaml | 6 ++++ impl/kunlunxin/functions/basic_op.cpp | 42 +++++++++++++++++++++++++++ 2 files changed, 48 insertions(+) diff --git a/impl/kunlunxin/convert_config.yaml b/impl/kunlunxin/convert_config.yaml index b1c1af376..f2bf026a7 100644 --- a/impl/kunlunxin/convert_config.yaml +++ b/impl/kunlunxin/convert_config.yaml @@ -1,2 +1,8 @@ - common_config: dtype: (float64)->float32, (int64)->int32 + +- diopiMaxPool2dWithIndices: + layout: NCHW + +- diopiMaxPool2dBackward: + layout: NCHW diff --git a/impl/kunlunxin/functions/basic_op.cpp b/impl/kunlunxin/functions/basic_op.cpp index f44f25afb..a62bd02db 100644 --- a/impl/kunlunxin/functions/basic_op.cpp +++ b/impl/kunlunxin/functions/basic_op.cpp @@ -514,5 +514,47 @@ DIOPI_API diopiError_t diopiCat(diopiContextHandle_t ctx, diopiTensorHandle_t ou return diopiSuccess; } +DIOPI_API diopiError_t diopiFlip(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiConstTensorHandle_t input, diopiSize_t dims) { + xdnn::Context* ctx_xpu = impl::kunlunxin::set_cur_ctx(ctx); + xdnn_pytorch::Tensor _in = impl::kunlunxin::build_xtorch_tensor(input); + xdnn_pytorch::Tensor _out = impl::kunlunxin::build_xtorch_tensor(out); + xtorch_vec _dims = impl::kunlunxin::build_xtorch_vec(dims); + + DIOPI_CALL_XDNN(xdnn_pytorch::flip(ctx_xpu, _in, _dims, _out)); + return diopiSuccess; +} + +DIOPI_API diopiError_t diopiMaxPool2dBackward(diopiContextHandle_t ctx, diopiTensorHandle_t grad_input, diopiConstTensorHandle_t grad_output, + diopiConstTensorHandle_t input, diopiSize_t kernel_size, diopiSize_t stride, diopiSize_t padding, + diopiSize_t dilation, bool ceil_mode, diopiConstTensorHandle_t indices) { + xdnn::Context* ctx_xpu = impl::kunlunxin::set_cur_ctx(ctx); + xdnn_pytorch::Tensor _in = impl::kunlunxin::build_xtorch_tensor(input); + xdnn_pytorch::Tensor _grad_in = impl::kunlunxin::build_xtorch_tensor(grad_input); + xdnn_pytorch::Tensor _grad_out = impl::kunlunxin::build_xtorch_tensor(grad_output); + xdnn_pytorch::Tensor _indices = impl::kunlunxin::build_xtorch_tensor(indices); + xtorch_vec _kernel_size = impl::kunlunxin::build_xtorch_vec(kernel_size); + xtorch_vec _stride = impl::kunlunxin::build_xtorch_vec(stride); + xtorch_vec _padding = impl::kunlunxin::build_xtorch_vec(padding); + xtorch_vec _dilation = impl::kunlunxin::build_xtorch_vec(dilation); + + DIOPI_CALL_XDNN( + xdnn_pytorch::max_pool2d_with_indices_backward(ctx_xpu, _grad_out, _in, _kernel_size, _stride, _padding, _dilation, ceil_mode, _indices, _grad_in)); + return diopiSuccess; +} + +DIOPI_API diopiError_t diopiMaxPool2dWithIndices(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiTensorHandle_t indices, diopiConstTensorHandle_t input, + diopiSize_t kernel_size, diopiSize_t stride, diopiSize_t padding, diopiSize_t dilation, bool ceil_mode) { + xdnn::Context* ctx_xpu = impl::kunlunxin::set_cur_ctx(ctx); + xdnn_pytorch::Tensor _in = impl::kunlunxin::build_xtorch_tensor(input); + xdnn_pytorch::Tensor _out = impl::kunlunxin::build_xtorch_tensor(out); + xdnn_pytorch::Tensor _indices = impl::kunlunxin::build_xtorch_tensor(indices); + xtorch_vec _kernel_size = impl::kunlunxin::build_xtorch_vec(kernel_size); + xtorch_vec _stride = impl::kunlunxin::build_xtorch_vec(stride); + xtorch_vec _padding = impl::kunlunxin::build_xtorch_vec(padding); + xtorch_vec _dilation = impl::kunlunxin::build_xtorch_vec(dilation); + DIOPI_CALL_XDNN(xdnn_pytorch::max_pool2d_with_indices(ctx_xpu, _in, _kernel_size, _stride, _padding, _dilation, ceil_mode, _out, _indices)); + return diopiSuccess; +} + } // namespace kunlunxin } // namespace impl