diff --git a/paddle/fluid/operators/conv_op_mlu.cc b/paddle/fluid/operators/conv_op_mlu.cc
index af1f3516bd1a6..0e0ed82e8798a 100644
--- a/paddle/fluid/operators/conv_op_mlu.cc
+++ b/paddle/fluid/operators/conv_op_mlu.cc
@@ -436,6 +436,8 @@ class MLUDepthwiseConvGradOpKernel : public framework::OpKernel<T> {
     Tensor output_grad_tensor(output_grad->type());
     const std::vector<int> perm_to_nhwc = {0, 2, 3, 1};
     const std::vector<int> perm_to_nchw = {0, 3, 1, 2};
+    const std::vector<int> perm_hwcm_to_mchw = {3, 2, 0, 1};
+    const std::vector<int> perm_mchw_to_hwcm = {2, 3, 1, 0};
     if (channel_last) {
       input_tensor.ShareDataWith(*input);
       output_grad_tensor.ShareDataWith(*output_grad);
@@ -462,10 +464,12 @@ class MLUDepthwiseConvGradOpKernel : public framework::OpKernel<T> {
 
     auto filter_grad_dims = filter_grad->dims();
     Tensor temp_filter_grad(filter_grad->type());
-    temp_filter_grad.mutable_data<T>({filter_grad_dims[0],
-                                      filter_grad_dims[2],
-                                      filter_grad_dims[3],
-                                      filter_grad_dims[1]},
+    // Details about setting diff_w to HWCN layout for better performance can
+    // be found in the CNNL documentation.
+    temp_filter_grad.mutable_data<T>({filter_grad_dims[perm_mchw_to_hwcm[0]],
+                                      filter_grad_dims[perm_mchw_to_hwcm[1]],
+                                      filter_grad_dims[perm_mchw_to_hwcm[2]],
+                                      filter_grad_dims[perm_mchw_to_hwcm[3]]},
                                      ctx.GetPlace());
 
     cnnlDataType_t tensor_dtype = ToCnnlDataType<T>();
@@ -474,7 +478,7 @@ class MLUDepthwiseConvGradOpKernel : public framework::OpKernel<T> {
     MLUCnnlTensorDesc out_grad_desc(
         output_grad_tensor, data_layout, tensor_dtype);
     MLUCnnlTensorDesc temp_filter_grad_desc(
-        temp_filter_grad, data_layout, tensor_dtype);
+        temp_filter_grad, CNNL_LAYOUT_HWCN, tensor_dtype);
 
     MLUCnnlConvolutionDesc conv_desc(in_dims_size,
                                      paddings.data(),
@@ -492,9 +496,9 @@ class MLUDepthwiseConvGradOpKernel : public framework::OpKernel<T> {
                           temp_filter_grad_desc.get(),
                           GetBasePtr(&temp_filter_grad));
 
-    // transpose filter_grad from MHWC to MCHW
+    // transpose filter_grad from HWCM to MCHW
     TransposeFromMLUTensor<T>(ctx,
-                              perm_hwcm_to_mchw,
+                              perm_hwcm_to_mchw,
                               &temp_filter_grad,
                               filter_grad,
                               false /*need_reshape_or_alloc*/);