Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[Numpy] add: numpy op rollaxis #17865

Merged
merged 1 commit into from
Apr 10, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 33 additions & 1 deletion python/mxnet/ndarray/numpy/_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
'true_divide', 'nonzero', 'quantile', 'percentile', 'shares_memory', 'may_share_memory', 'interp',
'diff', 'ediff1d', 'resize', 'polyval', 'nan_to_num', 'isnan', 'isinf', 'isposinf', 'isneginf', 'isfinite',
'atleast_1d', 'atleast_2d', 'atleast_3d',
'where', 'bincount', 'pad', 'cumsum', 'diag', 'diagonal']
'where', 'bincount', 'rollaxis', 'pad', 'cumsum', 'diag', 'diagonal']


@set_module('mxnet.ndarray.numpy')
Expand Down Expand Up @@ -8178,6 +8178,38 @@ def cumsum(a, axis=None, dtype=None, out=None):
return _api_internal.cumsum(a, axis, dtype, out)


@set_module('mxnet.ndarray.numpy')
def rollaxis(a, axis, start=0):
"""
Roll the specified axis backwards, until it lies in a given position.
a
Input array.
axis : integer
The axis to roll backwards. The positions of the other axes do not
change relative to one another.
start: int, optional
The axis is rolled until it lies before this position.
The default, 0, results in a “complete” roll.

Returns
-------
res : ndarray
A view after applying rollaxis to `a` is returned.

-----
Examples
--------
>>> a = np.ones((3,4,5,6))
>>> np.rollaxis(a, 3, 1).shape
(3, 6, 4, 5)
>>> np.rollaxis(a, 2).shape
(5, 3, 4, 6)
>>> np.rollaxis(a, 1, 4).shape
(3, 5, 6, 4)
"""
return _npi.rollaxis(a, axis, start)


@set_module('mxnet.ndarray.numpy')
def diag(v, k=0):
"""
Expand Down
2 changes: 0 additions & 2 deletions python/mxnet/numpy/fallback.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,6 @@
'rate',
'real',
'result_type',
'rollaxis',
'roots',
'searchsorted',
'select',
Expand Down Expand Up @@ -180,7 +179,6 @@
rate = onp.rate
real = onp.real
result_type = onp.result_type
rollaxis = onp.rollaxis
roots = onp.roots
searchsorted = onp.searchsorted
select = onp.select
Expand Down
39 changes: 37 additions & 2 deletions python/mxnet/numpy/multiarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@
'quantile', 'percentile', 'shares_memory', 'may_share_memory', 'diff', 'ediff1d', 'resize', 'matmul',
'nan_to_num', 'isnan', 'isinf', 'isposinf', 'isneginf', 'isfinite', 'polyval', 'where', 'bincount',
'atleast_1d', 'atleast_2d', 'atleast_3d',
'pad', 'cumsum', 'diag', 'diagonal']
'pad', 'cumsum', 'rollaxis', 'diag', 'diagonal']

__all__ += fallback.__all__

Expand Down Expand Up @@ -10339,7 +10339,42 @@ def cumsum(a, axis=None, dtype=None, out=None):
[ 4, 9, 15]])
"""
return _mx_nd_np.cumsum(a, axis=axis, dtype=dtype, out=out)
# pylint: enable=redefined-outer-name


# pylint: disable=redefined-outer-name
@set_module('mxnet.numpy')
def rollaxis(a, axis, start=0):
"""
Roll the specified axis backwards, until it lies in a given position.

Parameters
----------
a : ndarray
Input array.
axis : integer
The axis to roll backwards. The positions of the other axes do not
change relative to one another.
start: int, optional
The axis is rolled until it lies before this position.
The default, 0, results in a “complete” roll.

Returns
-------
res : ndarray
A view after applying rollaxis to `a` is returned.

-----
Examples
--------
>>> a = np.ones((3,4,5,6))
>>> np.rollaxis(a, 3, 1).shape
(3, 6, 4, 5)
>>> np.rollaxis(a, 2).shape
(5, 3, 4, 6)
>>> np.rollaxis(a, 1, 4).shape
(3, 5, 6, 4)
"""
return _mx_nd_np.rollaxis(a, axis, start)


@set_module('mxnet.numpy')
Expand Down
37 changes: 36 additions & 1 deletion python/mxnet/symbol/numpy/_symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@
'true_divide', 'quantile', 'percentile', 'shares_memory', 'may_share_memory', 'diff', 'ediff1d',
'resize', 'polyval', 'nan_to_num', 'isnan', 'isinf', 'isposinf', 'isneginf', 'isfinite',
'atleast_1d', 'atleast_2d', 'atleast_3d',
'where', 'bincount', 'pad', 'cumsum', 'diag', 'diagonal']
'where', 'bincount', 'rollaxis', 'pad', 'cumsum', 'diag', 'diagonal']


@set_module('mxnet.symbol.numpy')
Expand Down Expand Up @@ -7136,6 +7136,41 @@ def cumsum(a, axis=None, dtype=None, out=None):
return _npi.cumsum(a, axis=axis, dtype=dtype, out=out)


@set_module('mxnet.symbol.numpy')
def rollaxis(a, axis, start=0):
"""
Roll the specified axis backwards, until it lies in a given position.

Parameters
----------
a : ndarray
Input array.
axis : integer
The axis to roll backwards. The positions of the other axes do not
change relative to one another.
start: int, optional
The axis is rolled until it lies before this position.
The default, 0, results in a “complete” roll.

Returns
-------
res : ndarray
A view after applying rollaxis to `a` is returned.

-----
Examples
--------
>>> a = np.ones((3,4,5,6))
>>> np.rollaxis(a, 3, 1).shape
(3, 6, 4, 5)
>>> np.rollaxis(a, 2).shape
(5, 3, 4, 6)
>>> np.rollaxis(a, 1, 4).shape
(3, 5, 6, 4)
"""
return _npi.rollaxis(a, axis, start)


@set_module('mxnet.symbol.numpy')
def diag(v, k=0):
"""
Expand Down
98 changes: 98 additions & 0 deletions src/operator/numpy/np_matrix_op-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -517,6 +517,47 @@ void NumpyFlipForward(const nnvm::NodeAttrs& attrs,
NumpyFlipForwardImpl<xpu>(ctx, inputs, outputs, stride_, trailing_, flip_index);
}

struct NumpyRollaxisParam : public dmlc::Parameter<NumpyRollaxisParam> {
int axis;
int start;
DMLC_DECLARE_PARAMETER(NumpyRollaxisParam) {
DMLC_DECLARE_FIELD(axis)
.describe("The axis to roll backwards. "
"The positions of the other axes do not change relative to one another.");
DMLC_DECLARE_FIELD(start)
.set_default(0)
.describe("The axis is rolled until it lies before this position. "
"The default, 0, results in a “complete” roll.");
}
};

inline mxnet::TShape NumpyRollaxisShapeImpl(int axis,
int start,
const int& ndim) {
mxnet::TShape axes(ndim, -1);
if (axis < 0) {
axis += ndim;
}
if (start < 0) {
start += ndim;
}
if (axis < start) {
axes[start - 1] = axis;
} else {
axes[start] = axis;
}
int new_axis = 0;
for (int i = 0; i < axes.ndim(); i++) {
if (axes[i] < 0) {
if (new_axis == axis) {
new_axis++;
}
axes[i] = new_axis++;
}
}
return axes;
}

struct NumpyMoveaxisParam : public dmlc::Parameter<NumpyMoveaxisParam> {
mxnet::TShape source;
mxnet::TShape destination;
Expand Down Expand Up @@ -601,6 +642,63 @@ void NumpyMoveaxisCompute(const nnvm::NodeAttrs& attrs,
})
}

template<typename xpu>
void NumpyRollaxisCompute(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
using namespace mshadow;
using namespace mshadow::expr;
CHECK_EQ(inputs.size(), 1U);
CHECK_EQ(outputs.size(), 1U);
CHECK_EQ(req[0], kWriteTo) << "Rollaxis does not support inplace";
mxnet::TShape axes;
const NumpyRollaxisParam& param = nnvm::get<NumpyRollaxisParam>(attrs.parsed);
axes = NumpyRollaxisShapeImpl(param.axis, param.start, inputs[0].ndim());
MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, Dtype, {
TransposeImpl<xpu>(ctx.run_ctx, inputs[0], outputs[0], axes);
})
}

template<typename xpu>
void NumpyRollaxisBackward(const nnvm::NodeAttrs &attrs,
const OpContext &ctx,
const std::vector<TBlob> &inputs,
const std::vector<OpReqType> &req,
const std::vector<TBlob> &outputs) {
using namespace mshadow;
using namespace mshadow::expr;
const NumpyRollaxisParam& param = nnvm::get<NumpyRollaxisParam>(attrs.parsed);
int axis_origin = param.axis;
int start_origin = param.start;
int ndim = inputs[0].ndim();

int axis;
int start;

if (axis_origin < 0) {
axis_origin += ndim;
}

if (start_origin < 0) {
start_origin += ndim;
}

if (axis_origin < start_origin) {
axis = start_origin - 1;
start = axis_origin;
} else {
axis = start_origin;
start = axis_origin + 1;
}
mxnet::TShape axes;
axes = NumpyRollaxisShapeImpl(axis, start, inputs[0].ndim());
MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, Dtype, {
TransposeImpl<xpu>(ctx.run_ctx, inputs[0], outputs[0], axes);
})
}

struct NumpyRot90Param : public dmlc::Parameter<NumpyRot90Param> {
int k;
dmlc::optional<mxnet::TShape> axes;
Expand Down
64 changes: 64 additions & 0 deletions src/operator/numpy/np_matrix_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ namespace op {
DMLC_REGISTER_PARAMETER(NumpyTransposeParam);
DMLC_REGISTER_PARAMETER(NumpyRollParam);
DMLC_REGISTER_PARAMETER(NumpyMoveaxisParam);
DMLC_REGISTER_PARAMETER(NumpyRollaxisParam);
DMLC_REGISTER_PARAMETER(NumpyRot90Param);
DMLC_REGISTER_PARAMETER(NumpyReshapeParam);
DMLC_REGISTER_PARAMETER(NumpyXReshapeParam);
Expand Down Expand Up @@ -1190,6 +1191,69 @@ NNVM_REGISTER_OP(_npi_roll)
.add_argument("data", "NDArray-or-Symbol", "Input ndarray")
.add_arguments(NumpyRollParam::__FIELDS__());

bool NumpyRollaxisShape(const nnvm::NodeAttrs& attrs,
mxnet::ShapeVector *in_attrs,
mxnet::ShapeVector *out_attrs) {
const NumpyRollaxisParam& param = nnvm::get<NumpyRollaxisParam>(attrs.parsed);
// check 1 input, 1 output
CHECK_EQ(in_attrs->size(), 1U);
CHECK_EQ(out_attrs->size(), 1U);

// check transpose dimentions no more than 6
mxnet::TShape& shp = (*in_attrs)[0];
CHECK_LE(shp.ndim(), 6) << "Transpose support at most 6 dimensions";

// check axis and start range
CHECK_GE(param.axis, -shp.ndim())
<< "axis must be within the range of "
<< -shp.ndim() << " and " << shp.ndim() - 1;
CHECK_LT(param.axis, shp.ndim())
<< "axis must be within the range of "
<< -shp.ndim() << " and " << shp.ndim() - 1;
CHECK_GE(param.start, -shp.ndim())
<< "start must be within the range of "
<< -shp.ndim() << " and " << shp.ndim();
CHECK_LE(param.start, shp.ndim())
<< "start must be within the range of "
<< -shp.ndim() << " and " << shp.ndim();

// generate output shape
mxnet::TShape ret(shp.ndim(), -1);
mxnet::TShape axes;

axes = NumpyRollaxisShapeImpl(param.axis, param.start, shp.ndim());
for (int i = 0; i < shp.ndim(); ++i) {
CHECK(axes[i] < static_cast<int64_t>(shp.ndim()));
ret[i] = shp[axes[i]];
}
SHAPE_ASSIGN_CHECK(*out_attrs, 0, ret);
return shape_is_known(ret);
}

NNVM_REGISTER_OP(_npi_rollaxis)
.describe(R"code(Roll the specified axis backwards,
until it lies in a given position.)code" ADD_FILELINE)
.set_num_inputs(1)
.set_num_outputs(1)
.set_attr_parser(ParamParser<NumpyRollaxisParam>)
.set_attr<nnvm::FListInputNames>("FListInputNames",
[](const NodeAttrs& attrs) {
return std::vector<std::string>{"data"};
})
.set_attr<mxnet::FInferShape>("FInferShape", NumpyRollaxisShape)
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>)
.set_attr<FCompute>("FCompute<cpu>", NumpyRollaxisCompute<cpu>)
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_npi_rollaxis_backward"})
.add_argument("data", "NDArray-or-Symbol", "Input ndarray")
.add_arguments(NumpyRollaxisParam::__FIELDS__());

NNVM_REGISTER_OP(_npi_rollaxis_backward)
.set_num_inputs(1)
.set_num_outputs(1)
.set_attr_parser(ParamParser<NumpyRollaxisParam>)
.set_attr<nnvm::TIsBackward>("TIsBackward", true)
.set_attr<FCompute>("FCompute<cpu>", NumpyRollaxisBackward<cpu>);

template<>
void NumpyFlipForwardImpl<cpu>(const OpContext& ctx,
const std::vector<TBlob>& inputs,
Expand Down
6 changes: 6 additions & 0 deletions src/operator/numpy/np_matrix_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,12 @@ NNVM_REGISTER_OP(_backward_npi_flip)
NNVM_REGISTER_OP(_np_moveaxis)
.set_attr<FCompute>("FCompute<gpu>", NumpyMoveaxisCompute<gpu>);

NNVM_REGISTER_OP(_npi_rollaxis)
.set_attr<FCompute>("FCompute<gpu>", NumpyRollaxisCompute<gpu>);

NNVM_REGISTER_OP(_npi_rollaxis_backward)
.set_attr<FCompute>("FCompute<gpu>", NumpyRollaxisBackward<gpu>);

NNVM_REGISTER_OP(_npi_rot90)
.set_attr<FCompute>("FCompute<gpu>", NumpyRot90Compute<gpu>);

Expand Down
Loading