This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[FEATURE] Add oneDNN support for numpy concatenate operator #20652

Merged
merged 6 commits on Oct 30, 2021
3 changes: 2 additions & 1 deletion cpp-package/example/charRNN.cpp
@@ -125,7 +125,8 @@ Symbol LSTMUnroll(int num_lstm_layer, int sequence_length, int input_dim,
hidden_all.push_back(hidden);
}

auto hidden_concat = isTrain? Concat(hidden_all, hidden_all.size(), 0) : hidden_all[0];
auto hidden_concat =
isTrain ? Concat(hidden_all, hidden_all.size(), dmlc::optional<int>(0)) : hidden_all[0];
auto cls_weight = Symbol::Variable("cls_weight");
auto cls_bias = Symbol::Variable("cls_bias");
auto pred = FullyConnected("pred", hidden_concat, cls_weight, cls_bias, input_dim);
1 change: 0 additions & 1 deletion python/mxnet/amp/lists/symbol_fp16.py
@@ -641,7 +641,6 @@
'_mod',
'_not_equal',
'_npi_column_stack',
'_npi_concatenate',
'_npi_copysign',
'_npi_cross',
'_npi_dot',
8 changes: 4 additions & 4 deletions src/api/operator/numpy/np_matrix_op.cc
@@ -139,17 +139,17 @@ MXNET_REGISTER_API("_npi.concatenate")
using namespace runtime;
const nnvm::Op* op = Op::Get("_npi_concatenate");
nnvm::NodeAttrs attrs;
op::NumpyConcatenateParam param;
op::ConcatParam param;
int arg_size = args.num_args;
param.num_args = arg_size - 2;
if (args[arg_size - 2].type_code() == kNull) {
param.axis = dmlc::nullopt;
param.dim = dmlc::nullopt;
} else {
param.axis = args[arg_size - 2].operator int();
param.dim = args[arg_size - 2].operator int();
}
attrs.parsed = param;
attrs.op = op;
SetAttrDict<op::NumpyConcatenateParam>(&attrs);
SetAttrDict<op::ConcatParam>(&attrs);
int num_inputs = arg_size - 2;
std::vector<NDArray*> inputs;
inputs.reserve(num_inputs);
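
For reference, a minimal sketch of the frontend behaviour this hunk wires up: the Python-level axis argument (possibly None) reaches _npi.concatenate and is stored in ConcatParam::dim, with None becoming dmlc::nullopt. The import path and shapes below are illustrative assumptions, not part of this diff.

    # Illustrative sketch (assumes an MXNet build with the numpy frontend enabled).
    from mxnet import np

    x = np.array([[1., 1.], [2., 2.]])
    y = np.array([[3., 3.], [4., 4.], [5., 5.]])

    out_axis0 = np.concatenate([x, y], axis=0)     # axis=0 -> ConcatParam.dim = 0, shape (5, 2)
    out_flat  = np.concatenate([x, y], axis=None)  # axis=None -> dim = nullopt, flatten then join, shape (10,)
    print(out_axis0.shape, out_flat.shape)
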
34 changes: 26 additions & 8 deletions src/operator/nn/concat-inl.h
@@ -47,10 +47,12 @@ enum ConcatOpOutputs { kOut };

struct ConcatParam : public dmlc::Parameter<ConcatParam> {
int num_args;
int dim;
dmlc::optional<int> dim;
DMLC_DECLARE_PARAMETER(ConcatParam) {
DMLC_DECLARE_FIELD(num_args).set_lower_bound(1).describe("Number of inputs to be concated.");
DMLC_DECLARE_FIELD(dim).set_default(1).describe("the dimension to be concated.");
DMLC_DECLARE_FIELD(dim)
.set_default(dmlc::optional<int>(1))
.describe("the dimension to be concated.");
}
void SetAttrDict(std::unordered_map<std::string, std::string>* dict) {
std::ostringstream num_args_s, dim_s;
@@ -66,7 +68,7 @@ class ConcatOp {
public:
void Init(const ConcatParam& param) {
this->size_ = param.num_args;
this->dimension_ = param.dim;
this->dimension_ = param.dim.has_value() ? param.dim.value() : 0;
}

void Forward(const OpContext& ctx,
@@ -140,10 +142,18 @@ void ConcatCompute(const nnvm::NodeAttrs& attrs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
const ConcatParam& param = nnvm::get<ConcatParam>(attrs.parsed);
MSHADOW_TYPE_SWITCH(inputs[concat_enum::kData0].type_flag_, DType, {
std::vector<TBlob> in_data(param.num_args);
for (int i = 0; i < param.num_args; i++) {
if (!param.dim.has_value()) {
in_data[i] = inputs[i].reshape(mxnet::TShape(1, inputs[i].shape_.Size()));
} else {
in_data[i] = inputs[i];
}
}
MSHADOW_TYPE_SWITCH_WITH_BOOL(in_data[concat_enum::kData0].type_flag_, DType, {
ConcatOp<xpu, DType> op;
op.Init(param);
op.Forward(ctx, inputs, req, outputs);
op.Forward(ctx, in_data, req, outputs);
});
}
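
A small sketch, in plain NumPy terms, of what the new dim-has-no-value branch above does before handing the data to ConcatOp (names and shapes are illustrative):

    import numpy as np

    inputs = [np.ones((2, 2)), np.full((3, 2), 2.0)]
    # Mirrors inputs[i].reshape(TShape(1, shape_.Size())): each TBlob becomes a 1-D view.
    flattened = [a.reshape(a.size) for a in inputs]
    # ConcatOp::Init then falls back to dimension_ = 0, so the kernel joins along axis 0.
    out = np.concatenate(flattened, axis=0)
    print(out.shape)  # (10,)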

@@ -209,10 +219,18 @@ void ConcatGradCompute(const nnvm::NodeAttrs& attrs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
const ConcatParam& param = nnvm::get<ConcatParam>(attrs.parsed);
MSHADOW_TYPE_SWITCH(inputs[concat_enum::kOut].type_flag_, DType, {
std::vector<TBlob> out_data(param.num_args);
for (int i = 0; i < param.num_args; i++) {
if (!param.dim.has_value()) {
out_data[i] = outputs[i].reshape(mxnet::TShape(1, outputs[i].shape_.Size()));
} else {
out_data[i] = outputs[i];
}
}
MSHADOW_TYPE_SWITCH_WITH_BOOL(inputs[concat_enum::kOut].type_flag_, DType, {
ConcatOp<xpu, DType> op;
op.Init(param);
op.Backward(ctx, inputs[concat_enum::kOut], req, outputs);
op.Backward(ctx, inputs[concat_enum::kOut], req, out_data);
});
}
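
And the matching backward path, again as an illustrative NumPy sketch: the gradient of the flattened output is written back through 1-D views of the input-gradient buffers, so each input gradient keeps its original shape.

    import numpy as np

    in_shapes = [(2, 2), (3, 2)]
    grad_out = np.arange(10, dtype=np.float64)          # gradient w.r.t. the flattened output
    sizes = [int(np.prod(s)) for s in in_shapes]
    chunks = np.split(grad_out, np.cumsum(sizes)[:-1])  # one slice per input, as ConcatOp::Backward produces
    grads_in = [c.reshape(s) for c, s in zip(chunks, in_shapes)]  # undo the flattening
    print([g.shape for g in grads_in])                  # [(2, 2), (3, 2)]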

@@ -318,7 +336,7 @@ void ConcatCSRImpl(const nnvm::NodeAttrs& attrs,
using namespace csr;
const ConcatParam& param = nnvm::get<ConcatParam>(attrs.parsed);
int num_args = param.num_args;
int concat_dim = param.dim;
int concat_dim = param.dim.has_value() ? param.dim.value() : 0;
CHECK_EQ(inputs.size(), num_args);
CHECK_EQ(outputs.size(), 1);
int axis = CheckAxis(concat_dim, inputs[0].shape().ndim());
39 changes: 27 additions & 12 deletions src/operator/nn/concat.cc
@@ -39,12 +39,18 @@ bool ConcatShape(const nnvm::NodeAttrs& attrs,
CHECK_EQ(in_shape->size(), static_cast<size_t>(param_.num_args));
mxnet::TShape dshape;
dim_t size = 0;
int param_dim = param_.dim.has_value() ? param_.dim.value() : 0;
bool has_unknown_dim_size = false;
int axis = -1;
if (!param_.dim.has_value()) {
for (int i = 0; i < param_.num_args; ++i) {
(*in_shape)[i] = Shape1((*in_shape)[i].Size());
}
}
for (int i = 0; i < param_.num_args; ++i) {
mxnet::TShape tmp = (*in_shape)[i];
if (tmp.ndim() > 0) {
axis = CheckAxis(param_.dim, tmp.ndim());
axis = CheckAxis(param_dim, tmp.ndim());
has_unknown_dim_size = !mxnet::dim_size_is_known(tmp, axis) || has_unknown_dim_size;
size += tmp[axis];
tmp[axis] = -1;
@@ -54,7 +60,7 @@

mxnet::TShape tmp = (*out_shape)[0];
if (tmp.ndim() > 0) {
axis = CheckAxis(param_.dim, tmp.ndim());
axis = CheckAxis(param_dim, tmp.ndim());
tmp[axis] = -1;
shape_assign(&dshape, tmp);
}
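
The shape-inference change above can be summarised with a short Python sketch (the helper name and shapes are illustrative, not MXNet API):

    from math import prod

    def concat_infer_shape(in_shapes, dim):
        """Mimics ConcatShape: dim=None collapses every input to (Size(),), then infers along axis 0."""
        if dim is None:
            in_shapes = [(prod(s),) for s in in_shapes]
            dim = 0
        out = list(in_shapes[0])
        out[dim] = sum(s[dim] for s in in_shapes)
        return tuple(out)

    print(concat_infer_shape([(2, 2), (3, 2)], None))  # (10,)
    print(concat_infer_shape([(2, 2), (3, 2)], 0))     # (5, 2)
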
@@ -89,11 +95,12 @@ static bool RNNParamConcatShape(const nnvm::NodeAttrs& attrs,
mxnet::TShape dshape;
index_t size = 0;
std::vector<int> zero_indices;
int axis = -1;
int axis = -1;
int param_dim = param_.dim.has_value() ? param_.dim.value() : 0;
for (int i = 0; i < param_.num_args; ++i) {
mxnet::TShape tmp = (*in_shape)[i];
if (tmp.ndim() > 0) {
axis = CheckAxis(param_.dim, tmp.ndim());
axis = CheckAxis(param_dim, tmp.ndim());
if (!mxnet::dim_size_is_known(tmp, axis)) {
zero_indices.emplace_back(i);
} else {
@@ -107,7 +114,7 @@

mxnet::TShape tmp = (*out_shape)[0];
if (tmp.ndim() > 0) {
axis = CheckAxis(param_.dim, tmp.ndim());
axis = CheckAxis(param_dim, tmp.ndim());
tmp[axis] = -1;
shape_assign(&dshape, tmp);
}
@@ -193,13 +200,14 @@ inline static bool ConcatForwardInferStorageType(const nnvm::NodeAttrs& attrs,
auto& out_stype = out_attrs->at(0);
bool dispatched = false;
const ConcatParam& param = nnvm::get<ConcatParam>(attrs.parsed);
if (!dispatched && common::ContainsOnlyStorage(*in_attrs, kCSRStorage) && param.dim == 0) {
int param_dim = param.dim.has_value() ? param.dim.value() : 0;
if (!dispatched && common::ContainsOnlyStorage(*in_attrs, kCSRStorage) && param_dim == 0) {
dispatched =
storage_type_assign(&out_stype, kCSRStorage, dispatch_mode, DispatchMode::kFComputeEx);
}
#if MXNET_USE_ONEDNN == 1
if (!dispatched && dev_mask == mshadow::cpu::kDevMask &&
common::ContainsOnlyStorage(*in_attrs, kDefaultStorage) && param.dim > 0) {
common::ContainsOnlyStorage(*in_attrs, kDefaultStorage)) {
dispatched =
storage_type_assign(&out_stype, kDefaultStorage, dispatch_mode, DispatchMode::kFComputeEx);
}
@@ -225,10 +233,8 @@ inline static bool BackwardConcatStorageType(const nnvm::NodeAttrs& attrs,
std::vector<int>* out_attrs) {
DispatchMode wanted_mode;
#if MXNET_USE_ONEDNN == 1
const ConcatParam& param = nnvm::get<ConcatParam>(attrs.parsed);
CHECK_EQ(out_attrs->size(), in_attrs->size() - 1);
if (dev_mask == mshadow::cpu::kDevMask &&
common::ContainsOnlyStorage(*in_attrs, kDefaultStorage) && param.dim > 0)
if (dev_mask == mshadow::cpu::kDevMask && common::ContainsOnlyStorage(*in_attrs, kDefaultStorage))
wanted_mode = DispatchMode::kFComputeEx;
else
#endif // MXNET_USE_ONEDNN == 1
@@ -251,8 +257,9 @@ bool SupportDNNLConcat(const std::vector<NDArray>& arrs) {
return false;
int ndim = arr.shape().ndim();
const int dnnl_ndims = arr.GetDNNLData()->get_desc().data.ndims;
if (!(ndim == 2 || ndim == 4) || ndim != dnnl_ndims)
if ((ndim != 2 && ndim != 4) || ndim != dnnl_ndims) {
return false;
}
}
return true;
}
@@ -347,12 +354,14 @@ DMLC_REGISTER_PARAMETER(ConcatParam);
NNVM_REGISTER_OP(Concat)
MXNET_ADD_SPARSE_OP_ALIAS(concat)
.add_alias("concat")
.add_alias("_npi_concatenate")
.describe(R"code(Joins input arrays along a given axis.

.. note:: `Concat` is deprecated. Use `concat` instead.

The dimensions of the input arrays should be the same except the axis along
which they will be concatenated.
which they will be concatenated. With dimension parameter ``None``, input
arrays are flattened before being concatenated along axis 0.
The dimension of the output array along the concatenated axis will be equal
to the sum of the corresponding dimensions of the input arrays.

@@ -376,6 +385,11 @@
[ 7., 7.],
[ 8., 8.]]

concat(x,y,z,dim=None) = [1., 1., 2., 2.,
3., 3., 4., 4.,
5., 5., 6., 6.,
7., 7., 8., 8.]

Note that you cannot concat x,y,z along dimension 1 since dimension
0 is not the same for all the input arrays.

@@ -397,6 +411,7 @@
.add_arguments(ConcatParam::__FIELDS__());
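
The new dim=None example in the docstring above can be reproduced with plain NumPy semantics (a sketch for checking the documented numbers, not MXNet API):

    import numpy as np

    x = np.array([[1., 1.], [2., 2.]])
    y = np.array([[3., 3.], [4., 4.], [5., 5.]])
    z = np.array([[6., 6.], [7., 7.], [8., 8.]])

    # concat(x, y, z, dim=None): flatten every input, then join along axis 0.
    out = np.concatenate([a.reshape(-1) for a in (x, y, z)], axis=0)
    print(out)  # [1. 1. 2. 2. 3. 3. 4. 4. 5. 5. 6. 6. 7. 7. 8. 8.]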

NNVM_REGISTER_OP(_backward_Concat)
.add_alias("_backward_np_concat")
.set_num_inputs([](const NodeAttrs& attrs) {
#if MXNET_USE_ONEDNN == 1
const ConcatParam& params = nnvm::get<ConcatParam>(attrs.parsed);
5 changes: 4 additions & 1 deletion src/operator/nn/concat.cu
@@ -47,14 +47,17 @@ static void ConcatComputeExGPU(const nnvm::NodeAttrs& attrs,
}

NNVM_REGISTER_OP(Concat)
.add_alias("_npi_concatenate")
.set_attr<FCompute>("FCompute<gpu>", ConcatCompute<gpu>)
.set_attr<FComputeEx>("FComputeEx<gpu>", ConcatComputeExGPU);

NNVM_REGISTER_OP(_rnn_param_concat)
.set_attr<FCompute>("FCompute<gpu>", ConcatCompute<gpu>)
.set_attr<FComputeEx>("FComputeEx<gpu>", ConcatComputeExGPU);

NNVM_REGISTER_OP(_backward_Concat).set_attr<FCompute>("FCompute<gpu>", ConcatGradCompute<gpu>);
NNVM_REGISTER_OP(_backward_Concat)
.add_alias("_backward_np_concat")
.set_attr<FCompute>("FCompute<gpu>", ConcatGradCompute<gpu>);

} // namespace op
} // namespace mxnet
8 changes: 5 additions & 3 deletions src/operator/nn/dnnl/dnnl_concat.cc
@@ -64,7 +64,8 @@ void DNNLConcatForward(const nnvm::NodeAttrs& attrs,
TmpMemMgr::Get()->Init(ctx.requested[concat_enum::kTempSpace]);
const ConcatParam& param = nnvm::get<ConcatParam>(attrs.parsed);
const int num_in_data = param.num_args;
const int concat_dim = param.dim;
int concat_dim = param.dim.has_value() ? param.dim.value() : 0;
concat_dim = CheckAxis(concat_dim, in_data[concat_enum::kData0].shape().ndim());
std::vector<dnnl::memory::desc> data_md;
std::vector<const dnnl::memory*> data_mem;
data_md.reserve(num_in_data);
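
A small sketch of the axis normalisation the oneDNN path now performs via CheckAxis (the function below is an illustrative stand-in, not the MXNet helper itself):

    def normalize_axis(dim, ndim):
        """Roughly what CheckAxis does here: map a possibly-negative axis into [0, ndim)."""
        axis = 0 if dim is None else dim   # dim = nullopt -> 0 (inputs are already flattened)
        if axis < 0:
            axis += ndim
        if not 0 <= axis < ndim:
            raise ValueError(f"axis {axis} is out of range for a {ndim}-d input")
        return axis

    print(normalize_axis(-1, 4))    # 3
    print(normalize_axis(None, 1))  # 0
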
@@ -96,7 +97,8 @@ void DNNLConcatBackward(const nnvm::NodeAttrs& attrs,
TmpMemMgr::Get()->Init(ctx.requested[concat_enum::kTempSpace]);
const ConcatParam& param = nnvm::get<ConcatParam>(attrs.parsed);
const int num_in_data = param.num_args;
const int axis = param.dim;
int concat_dim = param.dim.has_value() ? param.dim.value() : 0;
concat_dim = CheckAxis(concat_dim, outputs[concat_enum::kData0].shape().ndim());
const auto gradz_mem = inputs[0].GetDNNLData();
/* init the offset */
dnnl::memory::dims offsets(outputs[0].shape().ndim());
@@ -112,7 +114,7 @@
auto from_md = gradz_mem->get_desc().submemory_desc(diff_src_tz, offsets);
auto from_mem =
new dnnl::memory(from_md, gradz_mem->get_engine(), gradz_mem->get_data_handle());
offsets[axis] += diff_src_tz[axis];
offsets[concat_dim] += diff_src_tz[concat_dim];

std::unordered_map<int, dnnl::memory> net_args(
{{DNNL_ARG_FROM, *gradz_mem}, {DNNL_ARG_TO, *gradi_mem.second}});