Skip to content

Commit

Permalink
[Kunlunxin] Add Flip/MaxPool2dWithIndices/MaxPool2dGrad (#865)
Browse files Browse the repository at this point in the history
* [kunlunxin] fix ci - add missing file

* [CI][kunlunxin]fix sub dtype

* [CI][kunlunxin]fix reduce_mean&reduce_sum;add more test cases

* [KUNLUNXIN] add flip/max_pool2d_with_indices/max_pool2d_grad
  • Loading branch information
brianlcy123 authored Jan 16, 2024
1 parent 42ae84b commit 03056c2
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 0 deletions.
6 changes: 6 additions & 0 deletions impl/kunlunxin/convert_config.yaml
Original file line number Diff line number Diff line change
@@ -1,2 +1,8 @@
- common_config:
dtype: (float64)->float32, (int64)->int32

- diopiMaxPool2dWithIndices:
layout: NCHW

- diopiMaxPool2dBackward:
layout: NCHW
42 changes: 42 additions & 0 deletions impl/kunlunxin/functions/basic_op.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -514,5 +514,47 @@ DIOPI_API diopiError_t diopiCat(diopiContextHandle_t ctx, diopiTensorHandle_t ou
return diopiSuccess;
}

DIOPI_API diopiError_t diopiFlip(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiConstTensorHandle_t input, diopiSize_t dims) {
xdnn::Context* ctx_xpu = impl::kunlunxin::set_cur_ctx(ctx);
xdnn_pytorch::Tensor _in = impl::kunlunxin::build_xtorch_tensor(input);
xdnn_pytorch::Tensor _out = impl::kunlunxin::build_xtorch_tensor(out);
xtorch_vec _dims = impl::kunlunxin::build_xtorch_vec(dims);

DIOPI_CALL_XDNN(xdnn_pytorch::flip(ctx_xpu, _in, _dims, _out));
return diopiSuccess;
}

DIOPI_API diopiError_t diopiMaxPool2dBackward(diopiContextHandle_t ctx, diopiTensorHandle_t grad_input, diopiConstTensorHandle_t grad_output,
diopiConstTensorHandle_t input, diopiSize_t kernel_size, diopiSize_t stride, diopiSize_t padding,
diopiSize_t dilation, bool ceil_mode, diopiConstTensorHandle_t indices) {
xdnn::Context* ctx_xpu = impl::kunlunxin::set_cur_ctx(ctx);
xdnn_pytorch::Tensor _in = impl::kunlunxin::build_xtorch_tensor(input);
xdnn_pytorch::Tensor _grad_in = impl::kunlunxin::build_xtorch_tensor(grad_input);
xdnn_pytorch::Tensor _grad_out = impl::kunlunxin::build_xtorch_tensor(grad_output);
xdnn_pytorch::Tensor _indices = impl::kunlunxin::build_xtorch_tensor(indices);
xtorch_vec _kernel_size = impl::kunlunxin::build_xtorch_vec(kernel_size);
xtorch_vec _stride = impl::kunlunxin::build_xtorch_vec(stride);
xtorch_vec _padding = impl::kunlunxin::build_xtorch_vec(padding);
xtorch_vec _dilation = impl::kunlunxin::build_xtorch_vec(dilation);

DIOPI_CALL_XDNN(
xdnn_pytorch::max_pool2d_with_indices_backward(ctx_xpu, _grad_out, _in, _kernel_size, _stride, _padding, _dilation, ceil_mode, _indices, _grad_in));
return diopiSuccess;
}

DIOPI_API diopiError_t diopiMaxPool2dWithIndices(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiTensorHandle_t indices, diopiConstTensorHandle_t input,
diopiSize_t kernel_size, diopiSize_t stride, diopiSize_t padding, diopiSize_t dilation, bool ceil_mode) {
xdnn::Context* ctx_xpu = impl::kunlunxin::set_cur_ctx(ctx);
xdnn_pytorch::Tensor _in = impl::kunlunxin::build_xtorch_tensor(input);
xdnn_pytorch::Tensor _out = impl::kunlunxin::build_xtorch_tensor(out);
xdnn_pytorch::Tensor _indices = impl::kunlunxin::build_xtorch_tensor(indices);
xtorch_vec _kernel_size = impl::kunlunxin::build_xtorch_vec(kernel_size);
xtorch_vec _stride = impl::kunlunxin::build_xtorch_vec(stride);
xtorch_vec _padding = impl::kunlunxin::build_xtorch_vec(padding);
xtorch_vec _dilation = impl::kunlunxin::build_xtorch_vec(dilation);
DIOPI_CALL_XDNN(xdnn_pytorch::max_pool2d_with_indices(ctx_xpu, _in, _kernel_size, _stride, _padding, _dilation, ceil_mode, _out, _indices));
return diopiSuccess;
}

} // namespace kunlunxin
} // namespace impl

0 comments on commit 03056c2

Please sign in to comment.