From 081aa58f81a536cc2a6168c51bdfb3b98bb5a17d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20G=C5=82omski?= Date: Tue, 11 May 2021 04:33:51 +0200 Subject: [PATCH] [v1.x][BUGFIX] Implement oneDNN deconvolution primitives to deconvolution 2D (#20107) * Use mkldnn deconvolution primitive in deconvolution * Apply clang-format * Refactor deconvolution version 1 * Refactor deconvolution version 2 and use permute_axes in IOLogicalSwapDesc * Refactor deconvolution version 3 * Enable Deconvolution2D test * Fix sanity * Fix windows builds * Fix deconvolution with bias test --- .../nn/mkldnn/mkldnn_deconvolution-inl.h | 377 ++++++++++ .../nn/mkldnn/mkldnn_deconvolution.cc | 696 +++++++----------- tests/python/mkl/test_mkldnn.py | 3 +- tests/python/unittest/test_operator.py | 34 +- 4 files changed, 668 insertions(+), 442 deletions(-) create mode 100644 src/operator/nn/mkldnn/mkldnn_deconvolution-inl.h diff --git a/src/operator/nn/mkldnn/mkldnn_deconvolution-inl.h b/src/operator/nn/mkldnn/mkldnn_deconvolution-inl.h new file mode 100644 index 000000000000..b51ec2a85650 --- /dev/null +++ b/src/operator/nn/mkldnn/mkldnn_deconvolution-inl.h @@ -0,0 +1,377 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file mkldnn_deconvolution-inl.h + * Naming convention: + * ________ + * (src) data --->|Deconv| + * weights --->| FWD |---> out (dst) + * bias --->|______| + * ________ + * (diff_src) data_grad <---|Deconv|<--- out_grad (diff_dst) + * (diff_weight) weights_grad <---| BWD |<--- data (src) + * (diff_bias) bias_grad <---| |<--- weight + * |______|<--- bias + * + * "out" in this (and .cc) file will always refer to the output of Deconv FWD and + * "out_grad" to its gradient. The corresponding MKLDNN names are in parentheses. + */ +#ifndef MXNET_OPERATOR_NN_MKLDNN_MKLDNN_DECONVOLUTION_INL_H_ +#define MXNET_OPERATOR_NN_MKLDNN_MKLDNN_DECONVOLUTION_INL_H_ + +#if MXNET_USE_MKLDNN == 1 +#include +#include +#include + +#include "../deconvolution-inl.h" +#include "./mkldnn_base-inl.h" +#include "./mkldnn_ops-inl.h" + +namespace mxnet { +namespace op { + +using deconv_fwd_t = mkldnn::deconvolution_forward; +using deconv_fwd_pd_t = mkldnn::deconvolution_forward::primitive_desc; + +using deconv_bwd_data_t = mkldnn::deconvolution_backward_data; +using deconv_bwd_data_pd_t = mkldnn::deconvolution_backward_data::primitive_desc; + +using deconv_bwd_weights_t = mkldnn::deconvolution_backward_weights; +using deconv_bwd_weights_pd_t = mkldnn::deconvolution_backward_weights::primitive_desc; + + + +// Swaps the logical order of dimensions that in plain format would correspond to input and output +// channels (for example: oihw => iohw, iohw => oihw, goihw => giohw). 
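+// A worked example (illustration only, not part of the original patch): with no groups the
+// axis order computed below is {1, 0, 2, 3}, so a 4D oihw descriptor is relabelled as iohw;
+// with num_group > 1 the group axis stays first and the order is {0, 2, 1, 3, 4}, turning
+// goihw into giohw. Only the logical dimension order changes; the underlying memory and its
+// physical layout stay untouched.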
+inline mkldnn::memory::desc IOLogicalSwapDesc(const mkldnn::memory::desc &desc, + const uint32_t num_group) { + std::vector order(desc.data.ndims); + std::iota(std::begin(order), std::end(order), 0); + const int offset = static_cast(num_group > 1); + std::swap(order[offset + 0], order[offset + 1]); + return desc.permute_axes(order); +} + +// Applies IOLogicalSwapDesc to MKLDNN memory of arr +inline void IOLogicalSwapMKLDNNMem(const NDArray &arr, const uint32_t num_group) { + mkldnn::memory::desc desc; + if (arr.IsMKLDNNData()) { + desc = arr.GetMKLDNNData()->get_desc(); + } else { + // GetMKLDNNData won't take groups into account when creating mkldnn::memory, we need to use + // descriptor from GetWeightDesc but with default format + const auto &temp = GetWeightDesc(arr, num_group); + desc = mkldnn::memory::desc( + temp.dims(), temp.data_type(), + static_cast(GetDefaultFormat(temp.data.ndims))); + } + const_cast(arr).UpdateMKLDNNMemDesc(IOLogicalSwapDesc(desc, num_group)); +} + +// Version of GetWeightsDesc for deconvolution (with swap) +inline mkldnn::memory::desc GetDeconvWeightsDesc(const NDArray &weights, const uint32_t num_group) { + return IOLogicalSwapDesc(GetWeightDesc(weights, num_group), num_group); +} + + + +class MKLDNNDeconvFwd { + public: + struct Tensors { + Tensors(const NDArray &data, const NDArray &weights, const NDArray *const bias, + const NDArray &out); + Tensors(const bool no_bias, const std::vector &inputs, + const std::vector &outputs); + + const NDArray &data; + const NDArray &weights; + const NDArray *const bias; + const NDArray &out; + }; + + static MKLDNNDeconvFwd &GetCached(const DeconvolutionParam ¶m, const Tensors &tensors); + static std::shared_ptr CreatePrimitiveDesc(const DeconvolutionParam ¶m, + const Tensors &tensors); + + MKLDNNDeconvFwd(const DeconvolutionParam ¶m, const Tensors &tensors); + void ControlWeightsFormat(const uint32_t num_group, const bool is_train, + const NDArray &weights) const; + void Execute(const uint32_t num_group, const OpReqType req, const Tensors &tensors) const; + + private: + const mkldnn::memory *DataMem(const NDArray &data) const; + const mkldnn::memory *WeightsMem(const uint32_t num_group, const NDArray &weights) const; + const mkldnn::memory *BiasMem(const NDArray &bias) const; + + mkldnn_output_t OutMem(const OpReqType req, const NDArray &out) const; + + private: + std::shared_ptr fwd; + std::shared_ptr fwd_pd; +}; + + +MKLDNNDeconvFwd::Tensors::Tensors(const bool no_bias, const std::vector &inputs, + const std::vector &outputs) + : data(inputs[deconv::kData]), + weights(inputs[deconv::kWeight]), + bias(no_bias ? 
nullptr : &inputs[deconv::kBias]), + out(outputs[deconv::kOut]) {} + +MKLDNNDeconvFwd::Tensors::Tensors(const NDArray &data, const NDArray &weights, + const NDArray *const bias, const NDArray &out) + : data(data), weights(weights), bias(bias), out(out) {} + +MKLDNNDeconvFwd::MKLDNNDeconvFwd(const DeconvolutionParam ¶m, const Tensors &tensors) + : fwd_pd(CreatePrimitiveDesc(param, tensors)) { + fwd = std::make_shared(*fwd_pd); +} + +inline const mkldnn::memory *MKLDNNDeconvFwd::DataMem(const NDArray &data) const { + return data.GetMKLDNNDataReorder(fwd_pd->src_desc()); +} + +inline const mkldnn::memory *MKLDNNDeconvFwd::WeightsMem(const uint32_t num_group, + const NDArray &weights) const { + return GetWeights(weights, fwd_pd->weights_desc(), num_group); +} + +inline const mkldnn::memory *MKLDNNDeconvFwd::BiasMem(const NDArray &bias) const { + return bias.GetMKLDNNData(); +} + +inline mkldnn_output_t MKLDNNDeconvFwd::OutMem(const OpReqType req, const NDArray &out) const { + return CreateMKLDNNMem(out, fwd_pd->dst_desc(), req); +} + + + +class MKLDNNDeconvBwd { + public: + struct ReadTensors { + ReadTensors(const bool no_bias, const std::vector &inputs); + const NDArray &data; + const NDArray &weights; + const NDArray *const bias; + const NDArray &out_grad; + }; + struct WriteTensors { + WriteTensors(const bool no_bias, const std::vector &outputs); + const NDArray &data_grad; + const NDArray &weights_grad; + const NDArray *const bias_grad; + }; + + static MKLDNNDeconvBwd &GetCached(const DeconvolutionParam ¶m, + const ReadTensors &read_tensors); + + static std::shared_ptr CreateDataPrimitiveDesc( + const DeconvolutionParam ¶m, const ReadTensors &read_tensors, + const deconv_fwd_pd_t &fwd_pd); + + static std::shared_ptr CreateWeightsPrimitiveDesc( + const DeconvolutionParam ¶m, const ReadTensors &read_tensors, + const deconv_fwd_pd_t &fwd_pd); + + MKLDNNDeconvBwd(const DeconvolutionParam ¶m, const ReadTensors &read_tensors); + + void Execute(const uint32_t num_group, const std::vector &req, + const ReadTensors &read_tensors, const WriteTensors &write_tensors) const; + + private: + void IOSwapWeightsTensors(const uint32_t num_group, const std::vector &req, + const NDArray &weights, const NDArray &weights_grad) const; + + // returns the output gradient memory used to calculate the data (input) gradient, + // which might be reused when calculating the gradient of weights + const mkldnn::memory *ScheduleBwdData(const uint32_t num_group, const OpReqType req, + const ReadTensors &read_tensors, + const WriteTensors &write_tensors) const; + + void ScheduleBwdWeights(const uint32_t num_group, const std::vector &req, + const ReadTensors &read_tensors, const WriteTensors &write_tensors, + const mkldnn::memory *const out_grad_mem) const; + + const mkldnn::memory *DataMem(const NDArray &data) const; + const mkldnn::memory *WeightsMem(const uint32_t num_group, const NDArray &weights) const; + + // for calculating the gradient of data (input) + const mkldnn::memory *OutGradMem(const NDArray &out_grad) const; + // for calculating the gradient of weights + const mkldnn::memory *OutGradMem(const NDArray &out_grad, + const mkldnn::memory *const out_grad_mem) const; + + mkldnn_output_t DataGradMem(const OpReqType req, const NDArray &data_grad) const; + mkldnn_output_t WeightsGradMem(const uint32_t num_group, const OpReqType req, + const NDArray &weights_grad) const; + mkldnn_output_t BiasGradMem(const OpReqType req, const NDArray *const bias) const; + + std::shared_ptr bwd_data_pd; + std::shared_ptr 
bwd_weights_pd;
+  std::shared_ptr<deconv_bwd_data_t> bwd_data;
+  std::shared_ptr<deconv_bwd_weights_t> bwd_weights;
+};
+
+
+MKLDNNDeconvBwd::ReadTensors::ReadTensors(const bool no_bias, const std::vector<NDArray> &inputs)
+    : data(inputs[deconv::kData + 1]),
+      weights(inputs[deconv::kWeight + 1]),
+      bias(no_bias ? nullptr : &inputs[deconv::kBias + 1]),
+      out_grad(inputs[deconv::kOut]) {}
+
+MKLDNNDeconvBwd::WriteTensors::WriteTensors(const bool no_bias, const std::vector<NDArray> &outputs)
+    : data_grad(outputs[deconv::kData]),
+      weights_grad(outputs[deconv::kWeight]),
+      bias_grad(no_bias ? nullptr : &outputs[deconv::kBias]) {}
+
+MKLDNNDeconvBwd::MKLDNNDeconvBwd(const DeconvolutionParam &param, const ReadTensors &read_tensors) {
+  const auto &fwd_pd = MKLDNNDeconvFwd::CreatePrimitiveDesc(
+      param, MKLDNNDeconvFwd::Tensors(read_tensors.data, read_tensors.weights, read_tensors.bias,
+                                      read_tensors.out_grad));
+  bwd_data_pd = CreateDataPrimitiveDesc(param, read_tensors, *fwd_pd);
+  bwd_weights_pd = CreateWeightsPrimitiveDesc(param, read_tensors, *fwd_pd);
+  bwd_data = std::make_shared<deconv_bwd_data_t>(*bwd_data_pd);
+  bwd_weights = std::make_shared<deconv_bwd_weights_t>(*bwd_weights_pd);
+}
+
+inline void MKLDNNDeconvBwd::IOSwapWeightsTensors(const uint32_t num_group,
+                                                  const std::vector<OpReqType> &req,
+                                                  const NDArray &weights,
+                                                  const NDArray &weights_grad) const {
+  if (req[deconv::kData]) {
+    IOLogicalSwapMKLDNNMem(weights, num_group);
+  }
+  if (req[deconv::kWeight] || (req.size() > deconv::kBias && req[deconv::kBias])) {
+    IOLogicalSwapMKLDNNMem(weights_grad, num_group);
+  }
+}
+
+inline const mkldnn::memory *MKLDNNDeconvBwd::DataMem(const NDArray &data) const {
+  return data.GetMKLDNNDataReorder(bwd_weights_pd->src_desc());
+}
+
+inline const mkldnn::memory *MKLDNNDeconvBwd::WeightsMem(const uint32_t num_group,
+                                                          const NDArray &weights) const {
+  return GetWeights(weights, bwd_data_pd->weights_desc(), num_group);
+}
+
+inline const mkldnn::memory *MKLDNNDeconvBwd::OutGradMem(const NDArray &out_grad) const {
+  return out_grad.GetMKLDNNDataReorder(bwd_data_pd->diff_dst_desc());
+}
+
+inline const mkldnn::memory *MKLDNNDeconvBwd::OutGradMem(
+    const NDArray &out_grad, const mkldnn::memory *const out_grad_mem) const {
+  return (out_grad_mem && out_grad_mem->get_desc() == bwd_weights_pd->diff_dst_desc())
+             ? out_grad_mem
+             : out_grad.GetMKLDNNDataReorder(bwd_weights_pd->diff_dst_desc());
+}
+
+inline mkldnn_output_t MKLDNNDeconvBwd::DataGradMem(const OpReqType req,
+                                                    const NDArray &data_grad) const {
+  return CreateMKLDNNMem(data_grad, bwd_data_pd->diff_src_desc(), req);
+}
+
+inline mkldnn_output_t MKLDNNDeconvBwd::WeightsGradMem(const uint32_t num_group,
+                                                        const OpReqType req,
+                                                        const NDArray &weights_grad) const {
+  // CreateMKLDNNWeightGrad always creates a new tensor as IsDefaultFormat always fails (because
+  // of the logical swap - explained in MKLDNNDeconvFwd::Execute). We try to reuse weights_grad
+  // memory (which, when not swapped, is always in default format), so here we check if, after a
+  // swap, weights_md will have a default format.
+  const auto &weights_md = bwd_weights_pd->diff_weights_desc();
+  if (req == OpReqType::kWriteTo && IsDefaultFormat(IOLogicalSwapDesc(weights_md, num_group))) {
+    return {OutDataOp::Noop, const_cast<NDArray &>(weights_grad).CreateMKLDNNData(weights_md)};
+  }
+  return CreateMKLDNNWeightGrad(weights_grad, weights_md, req);
+}
+
+inline mkldnn_output_t MKLDNNDeconvBwd::BiasGradMem(const OpReqType req,
+                                                    const NDArray *const bias) const {
+  return bias ?
CreateMKLDNNMem(*bias, bwd_weights_pd->diff_bias_desc(), req) + : mkldnn_output_t(OutDataOp::Noop, nullptr); +} + + + +// Utility class for creating operation descriptors of deconvolution primitives +class DeconvDescCreator { + public: + DeconvDescCreator(const DeconvolutionParam ¶m, const NDArray &data, const NDArray &weights, + const NDArray *const bias, const NDArray &out); + + // Imposes plain formats on memory descriptors with padding (so the next selected implementation + // will pass CheckImplSizeReq). After calling this method, new primitive descriptor (with new + // operator descriptor) should be created, which should select an implementation with matching + // size requirements. + // data_size, weights_size, out_size - size requirements of current implementation + // Returns whether successfully imposed a plain format on any of the data, weights, and output + // memory descriptors. + bool ImposePlainWherePadding(const size_t data_size, const size_t weights_size, + const size_t out_size); + bool CheckImplSizeReq(const size_t data_size, const size_t weights_size, + const size_t out_size) const; + + deconv_fwd_t::desc CreateFwdDesc() const; + deconv_bwd_data_t::desc CreateBwdDataDesc() const; + deconv_bwd_weights_t::desc CreateBwdWeightsDesc() const; + + private: + mkldnn::memory::desc data_md; + mkldnn::memory::desc weights_md; + mkldnn::memory::desc bias_md; + mkldnn::memory::desc out_md; + + mkldnn::memory::dims strides; + mkldnn::memory::dims padding; + mkldnn::memory::dims dilates; +}; + + +inline bool DeconvDescCreator::CheckImplSizeReq(const size_t data_size, const size_t weights_size, + const size_t out_size) const { + // MKLDNN introduced padded formats since 0.15 which require more memory + // compared to the actual size of the tensor. Currently, MKLDNN operators + // still reuse memory from memory planning, so here we need to accept only a + // kernel that has the expected memory size requirements (which is suboptimal) + return (data_size == GetMemDescSize(data_md) && weights_size == GetMemDescSize(weights_md) && + out_size == GetMemDescSize(out_md)); +} + +inline deconv_fwd_t::desc DeconvDescCreator::CreateFwdDesc() const { + return deconv_fwd_t::desc(mkldnn::prop_kind::forward_training, + mkldnn::algorithm::deconvolution_direct, data_md, weights_md, bias_md, + out_md, strides, dilates, padding, padding); +} + +inline deconv_bwd_data_t::desc DeconvDescCreator::CreateBwdDataDesc() const { + return deconv_bwd_data_t::desc(mkldnn::algorithm::deconvolution_direct, data_md, weights_md, + out_md, strides, dilates, padding, padding); +} + +inline deconv_bwd_weights_t::desc DeconvDescCreator::CreateBwdWeightsDesc() const { + return deconv_bwd_weights_t::desc(mkldnn::algorithm::deconvolution_direct, data_md, weights_md, + bias_md, out_md, strides, dilates, padding, padding); +} + +} // namespace op +} // namespace mxnet +#endif // MXNET_USE_MKLDNN == 1 +#endif // MXNET_OPERATOR_NN_MKLDNN_MKLDNN_DECONVOLUTION_INL_H__ diff --git a/src/operator/nn/mkldnn/mkldnn_deconvolution.cc b/src/operator/nn/mkldnn/mkldnn_deconvolution.cc index 65bf93298b95..7678567d95c8 100644 --- a/src/operator/nn/mkldnn/mkldnn_deconvolution.cc +++ b/src/operator/nn/mkldnn/mkldnn_deconvolution.cc @@ -19,505 +19,339 @@ /*! 
* \file mkldnn_deconvolution.cc - * \brief */ #if MXNET_USE_MKLDNN == 1 -#include "../deconvolution-inl.h" -#include "./mkldnn_base-inl.h" -#include "./mkldnn_ops-inl.h" +#include "./mkldnn_deconvolution-inl.h" namespace mxnet { namespace op { -bool SupportMKLDNNDeconv(const DeconvolutionParam ¶ms, - const NDArray &input) { - if (params.kernel.ndim() != 2) return false; - return (input.dtype() == mshadow::kFloat32 || input.dtype() == mshadow::kBfloat16) - && input.shape().ndim() == 4; +bool SupportMKLDNNDeconv(const DeconvolutionParam ¶ms, const NDArray &input) { + return params.kernel.ndim() == 2 && input.shape().ndim() == 4 && + (input.dtype() == mshadow::kFloat32 || input.dtype() == mshadow::kBfloat16); } -static inline mkldnn::memory::desc GetBiasDesc(mkldnn::memory::desc md) { - mkldnn::memory::dims dims(1); - // This is deconvolution on 4D data. The second dimension is the channel. - dims[0] = md.data.dims[1]; - return mkldnn::memory::desc( - dims, static_cast(md.data.data_type), - mkldnn::memory::format_tag::any); -} - -std::shared_ptr GetDeconvBwd_( - const mkldnn::memory::desc &data_md, const mkldnn::memory::desc &weights_md, - bool has_bias, const mkldnn::memory::desc &out_md, - const mkldnn::engine &engine, const mkldnn::memory::dims &strides, - const mkldnn::memory::dims &padding, const mkldnn::memory::dims &dilates) { - // MKL-DNN introduced padded formats since 0.15 which require more memory - // compared to the actual size of the tensor. Currently, MKL-DNN operators - // still reuse memory from memory planning, so here we need to select a - // suboptimal kernel for computation that has the expected memory size requirements - if (!has_bias) { - mkldnn::convolution_forward::desc desc( - mkldnn::prop_kind::forward_training, - mkldnn::algorithm::convolution_direct, out_md, weights_md, data_md, - strides, dilates, padding, padding); - auto deconv_pd = - std::make_shared(desc, - engine); - while (deconv_pd->dst_desc().get_size() != GetMemDescSize(data_md) || - deconv_pd->src_desc().get_size() != GetMemDescSize(out_md) || - deconv_pd->weights_desc().get_size() != GetMemDescSize(weights_md)) { - CHECK(deconv_pd->next_impl()) << "No implementation"; - } - return deconv_pd; - } else { - auto bias_md = GetBiasDesc(data_md); - mkldnn::convolution_forward::desc desc( - mkldnn::prop_kind::forward_training, - mkldnn::algorithm::convolution_direct, out_md, weights_md, bias_md, - data_md, strides, dilates, padding, padding); - auto deconv_pd = - std::make_shared(desc, - engine); - while (deconv_pd->dst_desc().get_size() != GetMemDescSize(data_md) || - deconv_pd->src_desc().get_size() != GetMemDescSize(out_md) || - deconv_pd->weights_desc().get_size() != GetMemDescSize(weights_md)) { - CHECK(deconv_pd->next_impl()) << "No implementation"; - } - return deconv_pd; - } -} - -std::shared_ptr -GetDeconvFwdImpl(const DeconvolutionParam ¶m, const NDArray &data, - const NDArray &weights, bool has_bias, const NDArray &output) { - auto data_md = GetMemDesc(data); - auto weight_md = GetWeightDesc(weights, param.num_group); - auto out_md = GetMemDesc(output); - auto engine = CpuEngine::Get()->get_engine(); - CHECK_GE(param.stride.ndim(), 2); - CHECK_GE(param.pad.ndim(), 2); - CHECK_GE(param.dilate.ndim(), 2); - mkldnn::memory::dims strides{0, 0}; - strides[0] = param.stride[0]; - strides[1] = param.stride[1]; - mkldnn::memory::dims padding{0, 0}; - padding[0] = param.pad[0]; - padding[1] = param.pad[1]; - mkldnn::memory::dims dilate{0, 0}; - dilate[0] = param.dilate[0] - 1; - dilate[1] = param.dilate[1] 
- 1; - auto bwd_pd = GetDeconvBwd_(data_md, weight_md, has_bias, out_md, engine, - strides, padding, dilate); - mkldnn::convolution_backward_data::desc desc( - mkldnn::algorithm::convolution_direct, out_md, weight_md, data_md, - strides, dilate, padding, padding); - auto deconv_pd = - std::make_shared( - desc, engine, *bwd_pd); - // MKL-DNN introduced padded formats since 0.15 which require more memory - // compared to the actual size of the tensor. Currently, MKL-DNN operators - // still reuse memory from memory planning, so here we need to select a - // suboptimal kernel for computation that has the expected memory size requirements - while (deconv_pd->diff_dst_desc().get_size() != GetMemDescSize(data_md) || - deconv_pd->diff_src_desc().get_size() != GetMemDescSize(out_md) || - deconv_pd->weights_desc().get_size() != GetMemDescSize(weight_md)) { - CHECK(deconv_pd->next_impl()) << "No implementation"; - } - return deconv_pd; -} - -std::shared_ptr -GetDeconvBwdDataImpl(const DeconvolutionParam ¶m, const NDArray &data, - const NDArray &weights, bool has_bias, - const NDArray &output) { - auto data_md = GetMemDesc(data); - auto weight_md = GetWeightDesc(weights, param.num_group); - auto out_md = GetMemDesc(output); - auto engine = CpuEngine::Get()->get_engine(); - CHECK_GE(param.stride.ndim(), 2); - CHECK_GE(param.pad.ndim(), 2); - CHECK_GE(param.dilate.ndim(), 2); - mkldnn::memory::dims strides{0, 0}; - strides[0] = param.stride[0]; - strides[1] = param.stride[1]; - mkldnn::memory::dims padding{0, 0}; - padding[0] = param.pad[0]; - padding[1] = param.pad[1]; - mkldnn::memory::dims dilate{0, 0}; - dilate[0] = param.dilate[0] - 1; - dilate[1] = param.dilate[1] - 1; - return GetDeconvBwd_(data_md, weight_md, has_bias, out_md, engine, strides, - padding, dilate); -} - -std::shared_ptr -GetDeconvBwdWeightsImpl( - const DeconvolutionParam ¶m, const NDArray &data, - const NDArray &weights, bool has_bias, const NDArray &output, - const mkldnn::convolution_forward::primitive_desc &fwd_pd) { - auto data_md = GetMemDesc(data); - auto weight_md = GetWeightDesc(weights, param.num_group); - auto out_md = GetMemDesc(output); - auto engine = CpuEngine::Get()->get_engine(); - CHECK_GE(param.stride.ndim(), 2); - CHECK_GE(param.pad.ndim(), 2); - CHECK_GE(param.dilate.ndim(), 2); - mkldnn::memory::dims strides{0, 0}; - strides[0] = param.stride[0]; - strides[1] = param.stride[1]; - mkldnn::memory::dims padding{0, 0}; - padding[0] = param.pad[0]; - padding[1] = param.pad[1]; - mkldnn::memory::dims dilate{0, 0}; - dilate[0] = param.dilate[0] - 1; - dilate[1] = param.dilate[1] - 1; - - // MKL-DNN introduced padded formats since 0.15 which require more memory - // compared to the actual size of the tensor. 
Currently, MKL-DNN operators - // still reuse memory from memory planning, so here we need to select a - // suboptimal kernel for computation that has the expected memory size requirements - if (!has_bias) { - mkldnn::convolution_backward_weights::desc desc( - mkldnn::algorithm::convolution_direct, out_md, weight_md, data_md, - strides, dilate, padding, padding); - auto deconv_pd = - std::make_shared( - desc, engine, fwd_pd); - while (deconv_pd->diff_dst_desc().get_size() != GetMemDescSize(data_md) || - deconv_pd->src_desc().get_size() != GetMemDescSize(out_md) || - deconv_pd->diff_weights_desc().get_size() != - GetMemDescSize(weight_md)) { - CHECK(deconv_pd->next_impl()) << "No implementation"; - } - return deconv_pd; - } else { - auto bias_md = GetBiasDesc(data_md); - mkldnn::convolution_backward_weights::desc desc( - mkldnn::algorithm::convolution_direct, out_md, weight_md, bias_md, - data_md, strides, dilate, padding, padding); - auto deconv_pd = - std::make_shared( - desc, engine, fwd_pd); - while (deconv_pd->diff_dst_desc().get_size() != GetMemDescSize(data_md) || - deconv_pd->src_desc().get_size() != GetMemDescSize(out_md) || - deconv_pd->diff_weights_desc().get_size() != - GetMemDescSize(weight_md)) { - CHECK(deconv_pd->next_impl()) << "No implementation"; - } - return deconv_pd; - } -} -class MKLDNNDeconvForward { - public: - MKLDNNDeconvForward(const DeconvolutionParam ¶m, const NDArray &data, - const NDArray &weights, bool has_bias, - const NDArray &output); - const mkldnn::convolution_backward_data &GetFwd() const { return *fwd; } - const mkldnn::convolution_backward_data::primitive_desc &GetPd() const { - return *fwd_pd; - } +void MKLDNNDeconvolutionForward(const nnvm::NodeAttrs &attrs, const OpContext &ctx, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs) { + TmpMemMgr::Get()->Init(ctx.requested[deconv::kTempSpace]); + const auto ¶m = nnvm::get(attrs.parsed); + const auto tensors = MKLDNNDeconvFwd::Tensors(param.no_bias, inputs, outputs); + const auto &fwd = MKLDNNDeconvFwd::GetCached(param, tensors); - private: - std::shared_ptr fwd; - std::shared_ptr fwd_pd; -}; // class MKLDNNDeconvForward - -MKLDNNDeconvForward::MKLDNNDeconvForward(const DeconvolutionParam ¶m, - const NDArray &data, - const NDArray &weights, bool has_bias, - const NDArray &output) - : fwd_pd(GetDeconvFwdImpl(param, data, weights, has_bias, output)) { - fwd = std::make_shared(GetPd()); + fwd.ControlWeightsFormat(param.num_group, ctx.is_train, tensors.weights); + fwd.Execute(param.num_group, req[deconv::kOut], tensors); } -static void MKLDNNDeconvFwdBiasPostProcess( - const DeconvolutionParam ¶m, const OpContext &ctx, const NDArray &bias, - const std::vector &out_data) { - // add bias, broadcast bias to dim 1: channel - if (!param.no_bias) { - // MKLDNN only supports float right now. - typedef float DType; - Stream *s = ctx.get_stream(); - Tensor b = bias.data().get(s); - // The output data is stored in a special MKLDNN format, - // converts its format to the default format. - // Unfortunately, MKLDNN doesn't support broadcast. 
- auto out_data_def = out_data[deconv::kOut].Reorder2Default(); - Tensor out_cpu = out_data_def.data().get(s); - out_cpu += mshadow::expr::broadcast<1>(b, out_cpu.shape_); - } -} - -MKLDNNDeconvForward &GetDeconvFwd(const nnvm::NodeAttrs &attrs, - const NDArray &data, const NDArray &weights, - const NDArray *bias, const NDArray &output) { +MKLDNNDeconvFwd &MKLDNNDeconvFwd::GetCached(const DeconvolutionParam ¶m, + const Tensors &tensors) { + using deconv_fwd_map = std::unordered_map; #if DMLC_CXX11_THREAD_LOCAL - static thread_local std::unordered_map - fwds; + static thread_local deconv_fwd_map fwds; #else - static MX_THREAD_LOCAL - std::unordered_map - fwds; + static MX_THREAD_LOCAL deconv_fwd_map fwds; #endif - const DeconvolutionParam ¶m = nnvm::get(attrs.parsed); DeconvSignature key(param); - // Here we can sign the conv op with NDArray because conv primitive will - // decide the right layout for the, so we only need to get the shape and the - // data type of the arrays. - key.AddSign(data); - key.AddSign(weights); - key.AddSign(output); - if (bias) key.AddSign(*bias); + key.AddSign(tensors.data); + key.AddSign(tensors.weights); + key.AddSign(tensors.out); + if (tensors.bias) { + key.AddSign(*tensors.bias); + } auto it = fwds.find(key); if (it == fwds.end()) { - bool has_bias = (bias != nullptr); - auto fwd = MKLDNNDeconvForward(param, data, weights, has_bias, output); + const MKLDNNDeconvFwd fwd(param, tensors); it = AddToCache(&fwds, key, fwd); } return it->second; } -void MKLDNNDeconvolutionForward(const nnvm::NodeAttrs &attrs, - const OpContext &ctx, - const std::vector &in_data, - const std::vector &req, - const std::vector &out_data) { - TmpMemMgr::Get()->Init(ctx.requested[deconv::kTempSpace]); - const DeconvolutionParam ¶m = nnvm::get(attrs.parsed); - - auto &data = in_data[deconv::kData]; - auto &weight = in_data[deconv::kWeight]; - const NDArray *bias = param.no_bias ? nullptr : &in_data[deconv::kBias]; - - MKLDNNDeconvForward &fwd = - GetDeconvFwd(attrs, data, weight, bias, out_data[deconv::kOut]); +std::shared_ptr MKLDNNDeconvFwd::CreatePrimitiveDesc( + const DeconvolutionParam ¶m, const Tensors &tensors) { + DeconvDescCreator ddc(param, tensors.data, tensors.weights, tensors.bias, tensors.out); + const auto &engine = CpuEngine::Get()->get_engine(); + const auto pd = std::make_shared(ddc.CreateFwdDesc(), engine); + const auto get_data_size = [&pd]() { return pd->src_desc().get_size(); }; + const auto get_weights_size = [&pd]() { return pd->weights_desc().get_size(); }; + const auto get_out_size = [&pd]() { return pd->dst_desc().get_size(); }; + + while (!ddc.CheckImplSizeReq(get_data_size(), get_weights_size(), get_out_size())) { + if (!pd->next_impl()) { + // ImposePlainWherePadding fails when all memory descriptors already have plain formats + // imposed, meaning there is no implementation with plain formats + CHECK(ddc.ImposePlainWherePadding(get_data_size(), get_weights_size(), get_out_size())) + << "No implementation of deconvolution forward propagation"; + *pd = deconv_fwd_pd_t(ddc.CreateFwdDesc(), engine); + } + } + return pd; +} - auto data_mem = data.GetMKLDNNDataReorder(fwd.GetPd().diff_dst_desc()); - const mkldnn::memory *weight_mem; - if (ctx.is_train) { +void MKLDNNDeconvFwd::ControlWeightsFormat(const uint32_t num_group, const bool is_train, + const NDArray &weights) const { + if (is_train) { // TODO(zhengda) kvstore doesn't handle MKLDNN correctly. Let's reorder it // to the default format for now. 
- if (weight.IsMKLDNNData()) - // This asks the engine to change the layout of the weight array after - // it's used. - weight.Reorder2DefaultAsync(); - weight_mem = - GetWeights(weight, fwd.GetPd().weights_desc(), param.num_group); + if (weights.IsMKLDNNData()) { + // This asks the engine to change the layout of the weights array after it's used. + weights.Reorder2DefaultAsync(); + } } else { - // For inference, we want to reorder the weight array so we don't need to + // For inference, we want to reorder the weights array so we don't need to // reorder data every time. - if (weight.IsDefaultData()) { - // We also need to modify the layout on the original weight array. The - // data conversion happens after the weight array is used. - weight.MKLDNNDataReorderAsync(fwd.GetPd().weights_desc()); - weight_mem = - GetWeights(weight, fwd.GetPd().weights_desc(), param.num_group); - + if (weights.IsDefaultData()) { + // We also need to modify the layout on the original weights array. + // The data conversion happens after the weights array is used. + weights.MKLDNNDataReorderAsync(IOLogicalSwapDesc(fwd_pd->weights_desc(), num_group)); } else { - weight_mem = weight.GetMKLDNNData(); - CHECK(weight_mem->get_desc() == fwd.GetPd().weights_desc()); + CHECK(weights.GetMKLDNNData()->get_desc() == + IOLogicalSwapDesc(fwd_pd->weights_desc(), num_group)); } } - mkldnn_output_t out_mem; - out_mem = CreateMKLDNNMem(out_data[deconv::kOut], fwd.GetPd().diff_src_desc(), - req[deconv::kOut]); - - mkldnn_args_map_t net_args; +} - net_args.insert({MKLDNN_ARG_DIFF_DST, *data_mem}); - net_args.insert({MKLDNN_ARG_WEIGHTS, *weight_mem}); - net_args.insert({MKLDNN_ARG_DIFF_SRC, *out_mem.second}); - MKLDNNStream::Get()->RegisterPrimArgs(fwd.GetFwd(), net_args); - CommitOutput(out_data[deconv::kOut], out_mem); - MKLDNNStream::Get()->Submit(); +void MKLDNNDeconvFwd::Execute(const uint32_t num_group, const OpReqType req, + const Tensors &tensors) const { + // MXNet (correctly) assumes that deconvolution is implemented using convolution primitives. + // For that, we would pass input tensor in place of output and output tensor in place of input + // (for appropriate convolution primitives: deconvolution forward = convolution backward data, + // deconvolution backward data = convolution forward). + // The convolution primitive expects weights tensor with the shape of + // (primitive_out_channels, primitive_in_channels, h, w), but with swapped input and output: + // primitive_out_channels = deconv_in_channels, primitive_in_channels = deconv_out_channels, + // so it becomes (deconv_in_channels, deconv_out_channels, h, w) and MXNet provides such tensor. + // + // MKLDNN deconvolution primitive also (as convolution) expects weights tensor with the shape of + // (primitive_out_channels, primitive_in_channels, h, w), but this time we don't swap input and + // output tensors, so: + // primitive_out_channels = deconv_out_channels, primitive_in_channels = deconv_in_channels, + // thus the current weights tensor won't fit (when deconv_out_channels != deconv_in_channels). + // However, underneath deconvolution MKLDNN also uses convolution, so even though it expects the + // weights tensor with the logical order of oihw, it wants its physical representation to + // match the order of iohw, which is the same as current weights tensor. + // + // So here we swap logical order of input and output dimensions for weights tensor just for + // MKLDNN operations. 
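+  // A concrete illustration (hypothetical sizes, not taken from the patch): a deconvolution
+  // with 8 input channels and 4 output channels stores its MXNet weights as (8, 4, kh, kw).
+  // The MKLDNN deconvolution primitive describes that same buffer as a logical
+  // (4, 8, kh, kw) tensor whose physical layout is iohw, so the swap below only relabels
+  // the dimensions; no data is copied or reordered.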
+ IOLogicalSwapMKLDNNMem(tensors.weights, num_group); + { + mkldnn_args_map_t net_args; + const auto &out_mem = OutMem(req, tensors.out); + + net_args.insert({MKLDNN_ARG_SRC, *DataMem(tensors.data)}); + net_args.insert({MKLDNN_ARG_WEIGHTS, *WeightsMem(num_group, tensors.weights)}); + net_args.insert({MKLDNN_ARG_DST, *out_mem.second}); + if (tensors.bias) { + net_args.insert({MKLDNN_ARG_BIAS, *BiasMem(*tensors.bias)}); + } - MKLDNNDeconvFwdBiasPostProcess(param, ctx, *bias, out_data); + // CommitOutput should run after RegisterPrimArgs for memory dependency + MKLDNNStream::Get()->RegisterPrimArgs(*fwd, net_args); + CommitOutput(tensors.out, out_mem); + MKLDNNStream::Get()->Submit(); + } + IOLogicalSwapMKLDNNMem(tensors.weights, num_group); // swap back from oihw to iohw } -class MKLDNNDeconvBackwardData { - std::shared_ptr bwd; - public: - std::shared_ptr bwd_pd; - MKLDNNDeconvBackwardData(const DeconvolutionParam ¶m, const NDArray &data, - const NDArray &weights, const NDArray &output); - const mkldnn::convolution_forward &GetBwd() const { return *bwd; } - const mkldnn::convolution_forward::primitive_desc &GetDataPd() const { - return *bwd_pd; - } -}; +void MKLDNNDeconvolutionBackward(const nnvm::NodeAttrs &attrs, const OpContext &ctx, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs) { + CHECK_NE(req[deconv::kWeight], kWriteInplace) << "Cannot write weights inplace"; -MKLDNNDeconvBackwardData::MKLDNNDeconvBackwardData( - const DeconvolutionParam ¶m, const NDArray &data, - const NDArray &weights, const NDArray &output) - : bwd_pd(GetDeconvBwdDataImpl(param, data, weights, false, output)) { - bwd = std::make_shared(GetDataPd()); -} + TmpMemMgr::Get()->Init(ctx.requested[deconv::kTempSpace]); + const auto ¶m = nnvm::get(attrs.parsed); + const auto read_tensors = MKLDNNDeconvBwd::ReadTensors(param.no_bias, inputs); + const auto write_tensors = MKLDNNDeconvBwd::WriteTensors(param.no_bias, outputs); + MKLDNNDeconvBwd &bwd = MKLDNNDeconvBwd::GetCached(param, read_tensors); -typedef ParamOpSign MKLDNNDeconvSignature; + bwd.Execute(param.num_group, req, read_tensors, write_tensors); +} -static inline MKLDNNDeconvBackwardData &GetDeconvBwdData( - const DeconvolutionParam ¶m, const NDArray &data, - const NDArray &weights, const NDArray &output) { +MKLDNNDeconvBwd &MKLDNNDeconvBwd::GetCached(const DeconvolutionParam ¶m, + const ReadTensors &read_tensors) { + using deconv_bwd_map = std::unordered_map; #if DMLC_CXX11_THREAD_LOCAL - static thread_local std::unordered_map - bwds; + static thread_local deconv_bwd_map bwds; #else - static MX_THREAD_LOCAL std::unordered_map - bwds; + static MX_THREAD_LOCAL deconv_bwd_map bwds; #endif - MKLDNNDeconvSignature key(param); - // Here we can sign the conv op with NDArray because conv primitive will - // decide the right layout for the, so we only need to get the shape and the - // data type of the arrays. 
- key.AddSign(data); - key.AddSign(weights); - key.AddSign(output); + DeconvSignature key(param); + key.AddSign(read_tensors.data); + key.AddSign(read_tensors.weights); + key.AddSign(read_tensors.out_grad); + if (read_tensors.bias) { + key.AddSign(*read_tensors.bias); + } auto it = bwds.find(key); if (it == bwds.end()) { - auto bwd = MKLDNNDeconvBackwardData(param, data, weights, output); + const MKLDNNDeconvBwd bwd(param, read_tensors); it = AddToCache(&bwds, key, bwd); } return it->second; } -class MKLDNNDeconvBackwardWeights { - std::shared_ptr bwd; - - public: - std::shared_ptr - bwd_data_pd; - MKLDNNDeconvBackwardWeights( - const DeconvolutionParam ¶m, const NDArray &data, - const NDArray &weights, const NDArray &output, - const mkldnn::convolution_forward::primitive_desc &bwd_data_pd); - const mkldnn::convolution_backward_weights &GetBwd() const { return *bwd; } - const mkldnn::convolution_backward_weights::primitive_desc &GetWeightsPd() - const { - return *bwd_data_pd; +std::shared_ptr MKLDNNDeconvBwd::CreateDataPrimitiveDesc( + const DeconvolutionParam ¶m, const ReadTensors &read_tensors, + const deconv_fwd_pd_t &fwd_pd) { + DeconvDescCreator ddc(param, read_tensors.data, read_tensors.weights, nullptr, + read_tensors.out_grad); + const auto &engine = CpuEngine::Get()->get_engine(); + const auto pd = std::make_shared(ddc.CreateBwdDataDesc(), engine, fwd_pd); + const auto get_data_size = [&pd]() { return pd->diff_src_desc().get_size(); }; + const auto get_weights_size = [&pd]() { return pd->weights_desc().get_size(); }; + const auto get_out_size = [&pd]() { return pd->diff_dst_desc().get_size(); }; + + while (!ddc.CheckImplSizeReq(get_data_size(), get_weights_size(), get_out_size())) { + if (!pd->next_impl()) { + // ImposePlainWherePadding fails when all memory descriptors already have plain formats + // imposed, meaning there is no implementation with plain formats + CHECK(ddc.ImposePlainWherePadding(get_data_size(), get_weights_size(), get_out_size())) + << "No implementation of deconvolution backward propagation"; + *pd = deconv_bwd_data_pd_t(ddc.CreateBwdDataDesc(), engine, fwd_pd); + } } -}; - -MKLDNNDeconvBackwardWeights::MKLDNNDeconvBackwardWeights( - const DeconvolutionParam ¶m, const NDArray &data, - const NDArray &weights, const NDArray &output, - const mkldnn::convolution_forward::primitive_desc &bwd_data_pd) - : bwd_data_pd(GetDeconvBwdWeightsImpl(param, data, weights, false, output, - bwd_data_pd)) { - bwd = std::make_shared(GetWeightsPd()); + return pd; } -static inline MKLDNNDeconvBackwardWeights &GetDeconvBwdWeights( - const DeconvolutionParam ¶m, const NDArray &data, - const NDArray &weights, const NDArray &output, - const mkldnn::convolution_forward::primitive_desc &bwd_data_pd) { -#if DMLC_CXX11_THREAD_LOCAL - static thread_local std::unordered_map - bwds; -#else - static MX_THREAD_LOCAL std::unordered_map - bwds; -#endif - MKLDNNDeconvSignature key(param); - // Here we can sign the conv op with NDArray because conv primitive will - // decide the right layout for the, so we only need to get the shape and the - // data type of the arrays. 
- key.AddSign(data); - key.AddSign(weights); - key.AddSign(output); - - auto it = bwds.find(key); - if (it == bwds.end()) { - auto bwd = - MKLDNNDeconvBackwardWeights(param, data, weights, output, bwd_data_pd); - auto ins_ret = bwds.insert( - std::pair(key, - bwd)); - CHECK(ins_ret.second); - it = ins_ret.first; +std::shared_ptr MKLDNNDeconvBwd::CreateWeightsPrimitiveDesc( + const DeconvolutionParam ¶m, const ReadTensors &read_tensors, + const deconv_fwd_pd_t &fwd_pd) { + DeconvDescCreator ddc(param, read_tensors.data, read_tensors.weights, read_tensors.bias, + read_tensors.out_grad); + const auto &engine = CpuEngine::Get()->get_engine(); + const auto pd = + std::make_shared(ddc.CreateBwdWeightsDesc(), engine, fwd_pd); + const auto get_data_size = [&pd]() { return pd->src_desc().get_size(); }; + const auto get_weights_size = [&pd]() { return pd->diff_weights_desc().get_size(); }; + const auto get_out_size = [&pd]() { return pd->diff_dst_desc().get_size(); }; + + while (!ddc.CheckImplSizeReq(get_data_size(), get_weights_size(), get_out_size())) { + if (!pd->next_impl()) { + // ImposePlainWherePadding fails when all memory descriptors already have plain formats + // imposed, meaning there is no implementation with plain formats + CHECK(ddc.ImposePlainWherePadding(get_data_size(), get_weights_size(), get_out_size())) + << "No implementation of calculating deconvolution weights gradient"; + *pd = deconv_bwd_weights_pd_t(ddc.CreateBwdWeightsDesc(), engine, fwd_pd); + } } - return it->second; + return pd; } -void MKLDNNDeconvolutionBackward(const nnvm::NodeAttrs &attrs, - const OpContext &ctx, - const std::vector &inputs, - const std::vector &req, - const std::vector &outputs) { - TmpMemMgr::Get()->Init(ctx.requested[deconv::kTempSpace]); - const std::vector &in_grad = outputs; - const DeconvolutionParam ¶m = nnvm::get(attrs.parsed); - - auto &data = inputs[deconv::kData + 1]; - auto &weight = inputs[deconv::kWeight + 1]; - auto &out_grad = inputs[deconv::kOut]; - - CHECK_NE(req[deconv::kWeight], kWriteInplace) - << "cannot write weight inplace"; - MKLDNNDeconvBackwardData &bwd_data = - GetDeconvBwdData(param, data, weight, inputs[deconv::kOut]); - auto out_grad_mem = - out_grad.GetMKLDNNDataReorder(bwd_data.GetDataPd().src_desc()); - if (req[deconv::kData]) { - auto weight_mem = GetWeights(weight, bwd_data.GetDataPd().weights_desc(), - param.num_group); - auto in_grad_mem = - CreateMKLDNNMem(in_grad[deconv::kData], bwd_data.GetDataPd().dst_desc(), - req[deconv::kData]); - mkldnn_args_map_t net_args = {{MKLDNN_ARG_SRC, *out_grad_mem}, - {MKLDNN_ARG_WEIGHTS, *weight_mem}, - {MKLDNN_ARG_DST, *in_grad_mem.second}}; - MKLDNNStream::Get()->RegisterPrimArgs(bwd_data.GetBwd(), net_args); - CommitOutput(in_grad[deconv::kData], in_grad_mem); +void MKLDNNDeconvBwd::Execute(const uint32_t num_group, const std::vector &req, + const ReadTensors &read_tensors, + const WriteTensors &write_tensors) const { + // swaps are explained in MKLDNNDeconvFwd::Execute + IOSwapWeightsTensors(num_group, req, read_tensors.weights, write_tensors.weights_grad); + { + auto *const out_grad_mem = + ScheduleBwdData(num_group, req[deconv::kData], read_tensors, write_tensors); + ScheduleBwdWeights(num_group, req, read_tensors, write_tensors, out_grad_mem); + MKLDNNStream::Get()->Submit(); } - if (req[deconv::kWeight]) { - MKLDNNDeconvBackwardWeights &bwd_weights = GetDeconvBwdWeights( - param, data, weight, inputs[deconv::kOut], bwd_data.GetDataPd()); - if (bwd_data.GetDataPd().src_desc() != - bwd_weights.GetWeightsPd().src_desc()) 
- out_grad_mem = - out_grad.GetMKLDNNDataReorder(bwd_weights.GetWeightsPd().src_desc()); - auto data_mem = - data.GetMKLDNNDataReorder(bwd_weights.GetWeightsPd().diff_dst_desc()); - auto in_grad_weight = CreateMKLDNNWeightGrad( - in_grad[deconv::kWeight], - bwd_weights.GetWeightsPd().diff_weights_desc(), req[deconv::kWeight]); - - mkldnn_args_map_t net_args = { - {MKLDNN_ARG_SRC, *out_grad_mem}, - {MKLDNN_ARG_DIFF_DST, *data_mem}, - {MKLDNN_ARG_DIFF_WEIGHTS, *in_grad_weight.second}}; - MKLDNNStream::Get()->RegisterPrimArgs(bwd_weights.GetBwd(), net_args); - CommitOutput(in_grad[deconv::kWeight], in_grad_weight); + IOSwapWeightsTensors(num_group, req, read_tensors.weights, write_tensors.weights_grad); +} + +const mkldnn::memory *MKLDNNDeconvBwd::ScheduleBwdData(const uint32_t num_group, + const OpReqType req, + const ReadTensors &read_tensors, + const WriteTensors &write_tensors) const { + if (req) { + mkldnn_args_map_t net_args; + auto *const out_grad_mem = OutGradMem(read_tensors.out_grad); + const auto &data_grad_mem = DataGradMem(req, write_tensors.data_grad); + + net_args.insert({MKLDNN_ARG_DIFF_DST, *out_grad_mem}); + net_args.insert({MKLDNN_ARG_WEIGHTS, *WeightsMem(num_group, read_tensors.weights)}); + net_args.insert({MKLDNN_ARG_DIFF_SRC, *data_grad_mem.second}); + + // CommitOutput should run after RegisterPrimArgs for memory dependency + MKLDNNStream::Get()->RegisterPrimArgs(*bwd_data, net_args); + CommitOutput(write_tensors.data_grad, data_grad_mem); + return out_grad_mem; } - MKLDNNStream::Get()->Submit(); + return nullptr; +} - if (!param.no_bias) { - typedef float DType; - Stream *s = ctx.get_stream(); - Tensor gbias = - in_grad[deconv::kBias].data().get(s); +void MKLDNNDeconvBwd::ScheduleBwdWeights(const uint32_t num_group, + const std::vector &req, + const ReadTensors &read_tensors, + const WriteTensors &write_tensors, + const mkldnn::memory *const out_grad_mem) const { + OpReqType weight_req = req[deconv::kWeight]; + OpReqType bias_req = req.size() > deconv::kBias ? req[deconv::kBias] : OpReqType::kNullOp; + if (weight_req || bias_req) { + mkldnn_args_map_t net_args; + const auto &weights_grad_mem = + WeightsGradMem(num_group, weight_req, write_tensors.weights_grad); + const auto &bias_grad_mem = BiasGradMem(bias_req, write_tensors.bias_grad); + + net_args.insert({MKLDNN_ARG_DIFF_DST, *OutGradMem(read_tensors.out_grad, out_grad_mem)}); + net_args.insert({MKLDNN_ARG_SRC, *DataMem(read_tensors.data)}); + net_args.insert({MKLDNN_ARG_DIFF_WEIGHTS, *weights_grad_mem.second}); + if (bias_grad_mem.second) { + net_args.insert({MKLDNN_ARG_DIFF_BIAS, *bias_grad_mem.second}); + } - NDArray temp = inputs[deconv::kOut]; - if (temp.IsMKLDNNData()) { - temp = temp.Reorder2Default(); + // CommitOutput should run after RegisterPrimArgs for memory dependency + MKLDNNStream::Get()->RegisterPrimArgs(*bwd_weights, net_args); + CommitOutput(write_tensors.weights_grad, weights_grad_mem); + if (bias_grad_mem.second) { + CommitOutput(*write_tensors.bias_grad, bias_grad_mem); } + } +} + + + +DeconvDescCreator::DeconvDescCreator(const DeconvolutionParam ¶m, const NDArray &data, + const NDArray &weights, const NDArray *const bias, + const NDArray &out) + : data_md(GetMemDesc(data)), + weights_md(GetDeconvWeightsDesc(weights, param.num_group)), + bias_md(bias ? 
GetMemDesc(*bias) : mkldnn::memory::desc()), + out_md(GetMemDesc(out)), + strides(param.stride.ndim()), + padding(param.pad.ndim()), + dilates(param.dilate.ndim()) { + // assuming only deconv2D is supported for now + CHECK_EQ(param.stride.ndim(), param.pad.ndim()); + CHECK_EQ(param.stride.ndim(), param.dilate.ndim()); + CHECK_EQ(param.stride.ndim(), 2); + for (int i = 0; i < param.stride.ndim(); ++i) { + strides[i] = param.stride[i]; + padding[i] = param.pad[i]; + dilates[i] = param.dilate[i] - 1; + } +} - Tensor grad = temp.data().get(s); - Assign(gbias, req[deconv::kBias], - mshadow::expr::sumall_except_dim<1>(grad)); +bool DeconvDescCreator::ImposePlainWherePadding(const size_t data_size, const size_t weights_size, + const size_t out_size) { + // Changing only one at a time, so maybe better implementations will be selected (than entirely + // plain one) + if (data_md.data.format_kind == dnnl_format_kind_any && data_size != GetMemDescSize(data_md)) { + data_md = GetDesc(data_md, GetDefaultFormat(data_md)); + return true; + } else if (out_md.data.format_kind == dnnl_format_kind_any && + out_size != GetMemDescSize(out_md)) { + out_md = GetDesc(out_md, GetDefaultFormat(out_md)); + return true; + } else if (weights_md.data.format_kind == dnnl_format_kind_any && + weights_size != GetMemDescSize(weights_md)) { + const int num_gr = (weights_md.data.ndims > data_md.data.ndims) ? weights_md.data.dims[0] : 1; + weights_md = IOLogicalSwapDesc(weights_md, num_gr); + weights_md = IOLogicalSwapDesc(GetDesc(weights_md, GetDefaultFormat(weights_md)), num_gr); + return true; } + return false; } } // namespace op diff --git a/tests/python/mkl/test_mkldnn.py b/tests/python/mkl/test_mkldnn.py index 60ebbfb97477..de0c249f52ab 100644 --- a/tests/python/mkl/test_mkldnn.py +++ b/tests/python/mkl/test_mkldnn.py @@ -469,10 +469,9 @@ def check_convolution_training(stype): @with_seed() -@unittest.skip("Flaky test https://github.com/apache/incubator-mxnet/issues/12579") def test_Deconvolution(): def check_Deconvolution_training(stype): - for shape in [(3, 3, 10), (3, 3, 10, 10)]: + for shape in [(3, 3, 10, 10)]: # testing only 2D for now data_tmp = np.random.randint(256, size=shape) data = mx.symbol.Variable('data', stype=stype) diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index b525075c84aa..e7ac61e55296 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -1660,22 +1660,38 @@ def test_deconvolution_forward_with_bias(): def check_deconvolution_forward_with_bias(shape=(1, 16, 5, 5), num_filter=32, num_group=1, kernel=(3, 3), pad=(1, 1)): x = mx.sym.Variable('x') w = mx.sym.Variable('w') - input_data = mx.random.uniform(-5, 5, shape, ctx=mx.cpu()) - y = mx.sym.Deconvolution(data=x, weight=w, num_filter=num_filter, num_group=num_group, kernel=kernel, no_bias=False, pad=pad) - exe = y.simple_bind(ctx=mx.cpu(), x=shape, grad_req='null') + b = mx.sym.Variable('b') + y_nb = mx.sym.Deconvolution(data=x, weight=w, num_filter=num_filter, num_group=num_group, kernel=kernel, no_bias=True, pad=pad) + y_b = mx.sym.Deconvolution(data=x, weight=w, bias=b, num_filter=num_filter, num_group=num_group, kernel=kernel, no_bias=False, pad=pad) + + + exe_nb = y_nb.simple_bind(ctx=mx.cpu(), x=shape, grad_req='null') + exe_b = y_b.simple_bind(ctx=mx.cpu(), x=shape, grad_req='null') + + + data = np.random.uniform(-5, 5, size=exe_b.arg_arrays[0].shape) + weights = np.random.normal(size=exe_b.arg_arrays[1].shape) + bias = 
np.random.normal(size=exe_b.arg_arrays[2].shape) + + def exe_forward(exe): + exe.arg_arrays[0][:] = data + exe.arg_arrays[1][:] = weights + if len(exe.arg_arrays) == 3: + exe.arg_arrays[2][:] = bias + return exe.forward(is_train=False)[0].asnumpy() + + out_nb = exe_forward(exe_nb) + out_b = exe_forward(exe_b) + bias = np.broadcast_to(bias, [np.prod(out_nb.shape[2:])] + [num_filter]).T + bias = np.broadcast_to(bias.reshape((num_filter, *out_nb.shape[2:])), out_b.shape) + assert_almost_equal(out_nb + bias, out_b) - exe.arg_arrays[0][:] = np.random.normal(size=exe.arg_arrays[0].shape) - exe.arg_arrays[1][:] = np.random.normal(size=exe.arg_arrays[1].shape) - exe.forward(is_train=False) - o = exe.outputs[0] - t = o.asnumpy() check_deconvolution_forward_with_bias((1, 16, 5), 32, 1, (3,), (1,)) check_deconvolution_forward_with_bias((32, 16, 5), 32, 1, (3,), (1,)) check_deconvolution_forward_with_bias((1, 16, 5, 5), 32, 1, (3, 3), (1, 1)) check_deconvolution_forward_with_bias((32, 16, 5, 5), 32, 1, (3, 3), (1, 1)) - def check_nearest_upsampling_with_shape(shapes, scale, root_scale): arr = {'arg_%d'%i: mx.random.uniform(-10.0, 10.0, shape, ctx=mx.cpu()).copyto(default_context()) for i, shape in zip(range(len(shapes)), shapes)} arr_grad = {'arg_%d'%i: mx.nd.zeros(shape) for i, shape in zip(range(len(shapes)), shapes)}