diff --git a/src/operator/numpy/np_matrix_op-inl.h b/src/operator/numpy/np_matrix_op-inl.h index 0bbe263cfc76..773b4a5772b2 100644 --- a/src/operator/numpy/np_matrix_op-inl.h +++ b/src/operator/numpy/np_matrix_op-inl.h @@ -517,6 +517,45 @@ void NumpyFlipForward(const nnvm::NodeAttrs& attrs, NumpyFlipForwardImpl(ctx, inputs, outputs, stride_, trailing_, flip_index); } +struct NumpyRollaxisParam : public dmlc::Parameter { + int axis; + int start; + DMLC_DECLARE_PARAMETER(NumpyRollaxisParam) { + DMLC_DECLARE_FIELD(axis) + .describe("The axis to roll backwards. The positions of the other axes do not change relative to one another."); + DMLC_DECLARE_FIELD(start) + .set_default(0) + .describe("The axis is rolled until it lies before this position. The default, 0, results in a “complete” roll."); + } +}; + +inline mxnet::TShape NumpyRollaxisShapeImpl(int axis, + int start, + const int& ndim) { + mxnet::TShape axes(ndim, -1); + if (axis < 0) { + axis += ndim; + } + if (start < 0){ + start += ndim; + } + if (axis < start){ + axes[start - 1] = axis; + } else { + axes[start] = axis; + } + int new_axis = 0; + for(int i = 0; i < axes.ndim(); i++){ + if (axes[i] < 0){ + if (new_axis == axis){ + new_axis++; + } + axes[i] = new_axis++; + } + } + return axes; +} + struct NumpyMoveaxisParam : public dmlc::Parameter { mxnet::TShape source; mxnet::TShape destination; @@ -601,6 +640,63 @@ void NumpyMoveaxisCompute(const nnvm::NodeAttrs& attrs, }) } +template +void NumpyRollaxisCompute(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + using namespace mshadow; + using namespace mshadow::expr; + CHECK_EQ(inputs.size(), 1U); + CHECK_EQ(outputs.size(), 1U); + CHECK_EQ(req[0], kWriteTo) << "Rollaxis does not support inplace"; + mxnet::TShape axes; + const NumpyRollaxisParam& param = nnvm::get(attrs.parsed); + axes = NumpyRollaxisShapeImpl(param.axis, param.start, inputs[0].ndim()); + MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, Dtype, { + TransposeImpl(ctx.run_ctx, inputs[0], outputs[0], axes); + }) +} + +template +void NumpyRollaxisBackward(const nnvm::NodeAttrs &attrs, + const OpContext &ctx, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs) { + using namespace mshadow; + using namespace mshadow::expr; + const NumpyRollaxisParam& param = nnvm::get(attrs.parsed); + int axis_origin = param.axis; + int start_origin = param.start; + int ndim = inputs[0].ndim(); + + int axis; + int start; + + if (axis_origin < 0) { + axis_origin += ndim; + } + + if (start_origin < 0) { + start_origin += ndim; + } + + if (axis_origin < start_origin){ + axis = start_origin - 1; + start = axis_origin; + } else { + axis = start_origin; + start = axis_origin + 1; + } + mxnet::TShape axes; + axes = NumpyRollaxisShapeImpl(axis, start, inputs[0].ndim()); + MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, Dtype, { + TransposeImpl(ctx.run_ctx, inputs[0], outputs[0], axes); + }) +} + struct NumpyRot90Param : public dmlc::Parameter { int k; dmlc::optional axes; diff --git a/src/operator/numpy/np_matrix_op.cc b/src/operator/numpy/np_matrix_op.cc index e9d269dd54d6..10696ca04e82 100644 --- a/src/operator/numpy/np_matrix_op.cc +++ b/src/operator/numpy/np_matrix_op.cc @@ -34,6 +34,7 @@ namespace op { DMLC_REGISTER_PARAMETER(NumpyTransposeParam); DMLC_REGISTER_PARAMETER(NumpyRollParam); DMLC_REGISTER_PARAMETER(NumpyMoveaxisParam); +DMLC_REGISTER_PARAMETER(NumpyRollaxisParam); DMLC_REGISTER_PARAMETER(NumpyRot90Param); DMLC_REGISTER_PARAMETER(NumpyReshapeParam); DMLC_REGISTER_PARAMETER(NumpyXReshapeParam); @@ -1190,6 +1191,61 @@ NNVM_REGISTER_OP(_npi_roll) .add_argument("data", "NDArray-or-Symbol", "Input ndarray") .add_arguments(NumpyRollParam::__FIELDS__()); +bool NumpyRollaxisShape(const nnvm::NodeAttrs& attrs, + mxnet::ShapeVector *in_attrs, + mxnet::ShapeVector *out_attrs) { + const NumpyRollaxisParam& param = nnvm::get(attrs.parsed); + // check 1 input, 1 output + CHECK_EQ(in_attrs->size(), 1U); + CHECK_EQ(out_attrs->size(), 1U); + + // check transpose dimentions no more than 6 + mxnet::TShape& shp = (*in_attrs)[0]; + CHECK_LE(shp.ndim(), 6) << "Transpose support at most 6 dimensions"; + + // check axis and start range + CHECK_GE(param.axis, -shp.ndim()) << "axis must be within the range of " << -shp.ndim() << " and " << shp.ndim() - 1; + CHECK_LT(param.axis, shp.ndim()) << "axis must be within the range of " << -shp.ndim() << " and " << shp.ndim() - 1; + CHECK_GE(param.start, -shp.ndim()) << "start must be within the range of " << -shp.ndim() << " and " << shp.ndim(); + CHECK_LE(param.start, shp.ndim()) << "start must be within the range of " << -shp.ndim() << " and " << shp.ndim(); + + // generate output shape + mxnet::TShape ret(shp.ndim(), -1); + mxnet::TShape axes; + + axes = NumpyRollaxisShapeImpl(param.axis, param.start, shp.ndim()); + for (int i = 0; i < shp.ndim(); ++i) { + CHECK(axes[i] < static_cast(shp.ndim())); + ret[i] = shp[axes[i]]; + } + SHAPE_ASSIGN_CHECK(*out_attrs, 0, ret); + return shape_is_known(ret); +} + +NNVM_REGISTER_OP(_npi_rollaxis) +.describe(R"code(Roll the specified axis backwards, +until it lies in a given position.)code" ADD_FILELINE) +.set_num_inputs(1) +.set_num_outputs(1) +.set_attr_parser(ParamParser) +.set_attr("FListInputNames", + [](const NodeAttrs& attrs) { + return std::vector{"data"}; + }) +.set_attr("FInferShape", NumpyRollaxisShape) +.set_attr("FInferType", ElemwiseType<1, 1>) +.set_attr("FCompute", NumpyRollaxisCompute) +.set_attr("FGradient", ElemwiseGradUseNone{"_npi_rollaxis_backward"}) +.add_argument("data", "NDArray-or-Symbol", "Input ndarray") +.add_arguments(NumpyRollaxisParam::__FIELDS__()); + +NNVM_REGISTER_OP(_npi_rollaxis_backward) +.set_num_inputs(1) +.set_num_outputs(1) +.set_attr_parser(ParamParser) +.set_attr("TIsBackward", true) +.set_attr("FCompute", NumpyRollaxisBackward); + template<> void NumpyFlipForwardImpl(const OpContext& ctx, const std::vector& inputs, diff --git a/src/operator/numpy/np_matrix_op.cu b/src/operator/numpy/np_matrix_op.cu index c9e896bc5b57..5222ed960727 100644 --- a/src/operator/numpy/np_matrix_op.cu +++ b/src/operator/numpy/np_matrix_op.cu @@ -112,6 +112,12 @@ NNVM_REGISTER_OP(_backward_npi_flip) NNVM_REGISTER_OP(_np_moveaxis) .set_attr("FCompute", NumpyMoveaxisCompute); +NNVM_REGISTER_OP(_npi_rollaxis) +.set_attr("FCompute", NumpyRollaxisCompute); + +NNVM_REGISTER_OP(_npi_rollaxis_backward) +.set_attr("FCompute", NumpyRollaxisBackward); + NNVM_REGISTER_OP(_npi_rot90) .set_attr("FCompute", NumpyRot90Compute); diff --git a/src/operator/numpy/np_rollaixs_op.cc b/src/operator/numpy/np_rollaixs_op.cc deleted file mode 100644 index a697b36df38e..000000000000 --- a/src/operator/numpy/np_rollaixs_op.cc +++ /dev/null @@ -1,65 +0,0 @@ - -#include "./np_rollaxis_op-inl.h" - -namespace mxnet { -namespace op { - -DMLC_REGISTER_PARAMETER(NumpyRollaxisParam); - -bool NumpyRollaxisShape(const nnvm::NodeAttrs& attrs, - mxnet::ShapeVector *in_attrs, - mxnet::ShapeVector *out_attrs) { - const NumpyRollaxisParam& param = nnvm::get(attrs.parsed); - // check 1 input, 1 output - CHECK_EQ(in_attrs->size(), 1U); - CHECK_EQ(out_attrs->size(), 1U); - - // check transpose dimentions no more than 6 - mxnet::TShape& shp = (*in_attrs)[0]; - CHECK_LE(shp.ndim(), 6) << "Transpose support at most 6 dimensions"; - - // check axis and start range - CHECK_GE(param.axis, -shp.ndim()) << "axis must be within the range of " << -shp.ndim() << " and " << shp.ndim() - 1; - CHECK_LT(param.axis, shp.ndim()) << "axis must be within the range of " << -shp.ndim() << " and " << shp.ndim() - 1; - CHECK_GE(param.start, -shp.ndim()) << "start must be within the range of " << -shp.ndim() << " and " << shp.ndim(); - CHECK_LE(param.start, shp.ndim()) << "start must be within the range of " << -shp.ndim() << " and " << shp.ndim(); - - // generate output shape - mxnet::TShape ret(shp.ndim(), -1); - mxnet::TShape axes; - - axes = NumpyRollaxisShapeImpl(param.axis, param.start, shp.ndim()); - for (int i = 0; i < shp.ndim(); ++i) { - CHECK(axes[i] < static_cast(shp.ndim())); - ret[i] = shp[axes[i]]; - } - SHAPE_ASSIGN_CHECK(*out_attrs, 0, ret); - return shape_is_known(ret); -} - -NNVM_REGISTER_OP(_npi_rollaxis) -.describe(R"code(Roll the specified axis backwards, -until it lies in a given position.)code" ADD_FILELINE) -.set_num_inputs(1) -.set_num_outputs(1) -.set_attr_parser(ParamParser) -.set_attr("FListInputNames", - [](const NodeAttrs& attrs) { - return std::vector{"data"}; - }) -.set_attr("FInferShape", NumpyRollaxisShape) -.set_attr("FInferType", ElemwiseType<1, 1>) -.set_attr("FCompute", NumpyRollaxisCompute) -.set_attr("FGradient", ElemwiseGradUseNone{"_npi_rollaxis_backward"}) -.add_argument("data", "NDArray-or-Symbol", "Input ndarray") -.add_arguments(NumpyRollaxisParam::__FIELDS__()); - -NNVM_REGISTER_OP(_npi_rollaxis_backward) -.set_num_inputs(1) -.set_num_outputs(1) -.set_attr_parser(ParamParser) -.set_attr("TIsBackward", true) -.set_attr("FCompute", NumpyRollaxisBackward); - -} // namespace op -} // namespace mxnet diff --git a/src/operator/numpy/np_rollaxis_op-inl.h b/src/operator/numpy/np_rollaxis_op-inl.h deleted file mode 100644 index d0c5e15bc538..000000000000 --- a/src/operator/numpy/np_rollaxis_op-inl.h +++ /dev/null @@ -1,117 +0,0 @@ -#ifndef MXNET_OPERATOR_NUMPY_NP_ROLLAXIS_OP_INL_H_ -#define MXNET_OPERATOR_NUMPY_NP_ROLLAXIS_OP_INL_H_ - -#include "../operator_common.h" -#include -#include "../tensor/matrix_op-inl.h" -#include "../nn/concat-inl.h" -#include "../../common/utils.h" -#include "../mxnet_op.h" -#include "../operator_common.h" -#include "../elemwise_op_common.h" -#include "../tensor/broadcast_reduce_op.h" - -namespace mxnet { -namespace op { - -struct NumpyRollaxisParam : public dmlc::Parameter { - int axis; - int start; - DMLC_DECLARE_PARAMETER(NumpyRollaxisParam) { - DMLC_DECLARE_FIELD(axis) - .describe("The axis to roll backwards. The positions of the other axes do not change relative to one another."); - DMLC_DECLARE_FIELD(start) - .set_default(0) - .describe("The axis is rolled until it lies before this position. The default, 0, results in a “complete” roll."); - } -}; - -inline mxnet::TShape NumpyRollaxisShapeImpl(int axis, - int start, - const int& ndim) { - mxnet::TShape axes(ndim, -1); - if (axis < 0) { - axis += ndim; - } - if (start < 0){ - start += ndim; - } - if (axis < start){ - axes[start - 1] = axis; - } else { - axes[start] = axis; - } - int new_axis = 0; - for(int i = 0; i < axes.ndim(); i++){ - if (axes[i] < 0){ - if (new_axis == axis){ - new_axis++; - } - axes[i] = new_axis++; - } - } - return axes; -} - - -template -void NumpyRollaxisCompute(const nnvm::NodeAttrs& attrs, - const OpContext& ctx, - const std::vector& inputs, - const std::vector& req, - const std::vector& outputs) { - using namespace mshadow; - using namespace mshadow::expr; - CHECK_EQ(inputs.size(), 1U); - CHECK_EQ(outputs.size(), 1U); - CHECK_EQ(req[0], kWriteTo) << "Rollaxis does not support inplace"; - mxnet::TShape axes; - const NumpyRollaxisParam& param = nnvm::get(attrs.parsed); - axes = NumpyRollaxisShapeImpl(param.axis, param.start, inputs[0].ndim()); - MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, Dtype, { - TransposeImpl(ctx.run_ctx, inputs[0], outputs[0], axes); - }) -} - -template -void NumpyRollaxisBackward(const nnvm::NodeAttrs &attrs, - const OpContext &ctx, - const std::vector &inputs, - const std::vector &req, - const std::vector &outputs) { - using namespace mshadow; - using namespace mshadow::expr; - const NumpyRollaxisParam& param = nnvm::get(attrs.parsed); - int axis_origin = param.axis; - int start_origin = param.start; - int ndim = inputs[0].ndim(); - - int axis; - int start; - - if (axis_origin < 0) { - axis_origin += ndim; - } - - if (start_origin < 0) { - start_origin += ndim; - } - - if (axis_origin < start_origin){ - axis = start_origin - 1; - start = axis_origin; - } else { - axis = start_origin; - start = axis_origin + 1; - } - mxnet::TShape axes; - axes = NumpyRollaxisShapeImpl(axis, start, inputs[0].ndim()); - MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, Dtype, { - TransposeImpl(ctx.run_ctx, inputs[0], outputs[0], axes); - }) -} - -} // namespace op -} // namespace mxnet - -#endif // MXNET_OPERATOR_NUMPY_NP_ROLLAXIS_OP_INL_H_ diff --git a/src/operator/numpy/np_rollaxis_op.cu b/src/operator/numpy/np_rollaxis_op.cu deleted file mode 100644 index 77914e51ca2c..000000000000 --- a/src/operator/numpy/np_rollaxis_op.cu +++ /dev/null @@ -1,13 +0,0 @@ -#include "./np_rollaxis_op-inl.h" - -namespace mxnet{ -namespace op{ - -NNVM_REGISTER_OP(_npi_rollaxis) -.set_attr("FCompute", NumpyRollaxisCompute); - -NNVM_REGISTER_OP(_npi_rollaxis_backward) -.set_attr("FCompute", NumpyRollaxisBackward); - -} -} \ No newline at end of file