Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
mdf: move rollaxis impl to matrix_op
Browse files Browse the repository at this point in the history
  • Loading branch information
yijunc committed Mar 19, 2020
1 parent 3439d47 commit ba318d5
Show file tree
Hide file tree
Showing 6 changed files with 158 additions and 195 deletions.
96 changes: 96 additions & 0 deletions src/operator/numpy/np_matrix_op-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -517,6 +517,45 @@ void NumpyFlipForward(const nnvm::NodeAttrs& attrs,
NumpyFlipForwardImpl<xpu>(ctx, inputs, outputs, stride_, trailing_, flip_index);
}

struct NumpyRollaxisParam : public dmlc::Parameter<NumpyRollaxisParam> {
int axis;
int start;
DMLC_DECLARE_PARAMETER(NumpyRollaxisParam) {
DMLC_DECLARE_FIELD(axis)
.describe("The axis to roll backwards. The positions of the other axes do not change relative to one another.");
DMLC_DECLARE_FIELD(start)
.set_default(0)
.describe("The axis is rolled until it lies before this position. The default, 0, results in a “complete” roll.");
}
};

inline mxnet::TShape NumpyRollaxisShapeImpl(int axis,
int start,
const int& ndim) {
mxnet::TShape axes(ndim, -1);
if (axis < 0) {
axis += ndim;
}
if (start < 0){
start += ndim;
}
if (axis < start){
axes[start - 1] = axis;
} else {
axes[start] = axis;
}
int new_axis = 0;
for(int i = 0; i < axes.ndim(); i++){
if (axes[i] < 0){
if (new_axis == axis){
new_axis++;
}
axes[i] = new_axis++;
}
}
return axes;
}

struct NumpyMoveaxisParam : public dmlc::Parameter<NumpyMoveaxisParam> {
mxnet::TShape source;
mxnet::TShape destination;
Expand Down Expand Up @@ -601,6 +640,63 @@ void NumpyMoveaxisCompute(const nnvm::NodeAttrs& attrs,
})
}

template<typename xpu>
void NumpyRollaxisCompute(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
using namespace mshadow;
using namespace mshadow::expr;
CHECK_EQ(inputs.size(), 1U);
CHECK_EQ(outputs.size(), 1U);
CHECK_EQ(req[0], kWriteTo) << "Rollaxis does not support inplace";
mxnet::TShape axes;
const NumpyRollaxisParam& param = nnvm::get<NumpyRollaxisParam>(attrs.parsed);
axes = NumpyRollaxisShapeImpl(param.axis, param.start, inputs[0].ndim());
MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, Dtype, {
TransposeImpl<xpu>(ctx.run_ctx, inputs[0], outputs[0], axes);
})
}

template<typename xpu>
void NumpyRollaxisBackward(const nnvm::NodeAttrs &attrs,
const OpContext &ctx,
const std::vector<TBlob> &inputs,
const std::vector<OpReqType> &req,
const std::vector<TBlob> &outputs) {
using namespace mshadow;
using namespace mshadow::expr;
const NumpyRollaxisParam& param = nnvm::get<NumpyRollaxisParam>(attrs.parsed);
int axis_origin = param.axis;
int start_origin = param.start;
int ndim = inputs[0].ndim();

int axis;
int start;

if (axis_origin < 0) {
axis_origin += ndim;
}

if (start_origin < 0) {
start_origin += ndim;
}

if (axis_origin < start_origin){
axis = start_origin - 1;
start = axis_origin;
} else {
axis = start_origin;
start = axis_origin + 1;
}
mxnet::TShape axes;
axes = NumpyRollaxisShapeImpl(axis, start, inputs[0].ndim());
MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, Dtype, {
TransposeImpl<xpu>(ctx.run_ctx, inputs[0], outputs[0], axes);
})
}

struct NumpyRot90Param : public dmlc::Parameter<NumpyRot90Param> {
int k;
dmlc::optional<mxnet::TShape> axes;
Expand Down
56 changes: 56 additions & 0 deletions src/operator/numpy/np_matrix_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ namespace op {
DMLC_REGISTER_PARAMETER(NumpyTransposeParam);
DMLC_REGISTER_PARAMETER(NumpyRollParam);
DMLC_REGISTER_PARAMETER(NumpyMoveaxisParam);
DMLC_REGISTER_PARAMETER(NumpyRollaxisParam);
DMLC_REGISTER_PARAMETER(NumpyRot90Param);
DMLC_REGISTER_PARAMETER(NumpyReshapeParam);
DMLC_REGISTER_PARAMETER(NumpyXReshapeParam);
Expand Down Expand Up @@ -1190,6 +1191,61 @@ NNVM_REGISTER_OP(_npi_roll)
.add_argument("data", "NDArray-or-Symbol", "Input ndarray")
.add_arguments(NumpyRollParam::__FIELDS__());

bool NumpyRollaxisShape(const nnvm::NodeAttrs& attrs,
mxnet::ShapeVector *in_attrs,
mxnet::ShapeVector *out_attrs) {
const NumpyRollaxisParam& param = nnvm::get<NumpyRollaxisParam>(attrs.parsed);
// check 1 input, 1 output
CHECK_EQ(in_attrs->size(), 1U);
CHECK_EQ(out_attrs->size(), 1U);

// check transpose dimentions no more than 6
mxnet::TShape& shp = (*in_attrs)[0];
CHECK_LE(shp.ndim(), 6) << "Transpose support at most 6 dimensions";

// check axis and start range
CHECK_GE(param.axis, -shp.ndim()) << "axis must be within the range of " << -shp.ndim() << " and " << shp.ndim() - 1;
CHECK_LT(param.axis, shp.ndim()) << "axis must be within the range of " << -shp.ndim() << " and " << shp.ndim() - 1;
CHECK_GE(param.start, -shp.ndim()) << "start must be within the range of " << -shp.ndim() << " and " << shp.ndim();
CHECK_LE(param.start, shp.ndim()) << "start must be within the range of " << -shp.ndim() << " and " << shp.ndim();

// generate output shape
mxnet::TShape ret(shp.ndim(), -1);
mxnet::TShape axes;

axes = NumpyRollaxisShapeImpl(param.axis, param.start, shp.ndim());
for (int i = 0; i < shp.ndim(); ++i) {
CHECK(axes[i] < static_cast<int64_t>(shp.ndim()));
ret[i] = shp[axes[i]];
}
SHAPE_ASSIGN_CHECK(*out_attrs, 0, ret);
return shape_is_known(ret);
}

NNVM_REGISTER_OP(_npi_rollaxis)
.describe(R"code(Roll the specified axis backwards,
until it lies in a given position.)code" ADD_FILELINE)
.set_num_inputs(1)
.set_num_outputs(1)
.set_attr_parser(ParamParser<NumpyRollaxisParam>)
.set_attr<nnvm::FListInputNames>("FListInputNames",
[](const NodeAttrs& attrs) {
return std::vector<std::string>{"data"};
})
.set_attr<mxnet::FInferShape>("FInferShape", NumpyRollaxisShape)
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>)
.set_attr<FCompute>("FCompute<cpu>", NumpyRollaxisCompute<cpu>)
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_npi_rollaxis_backward"})
.add_argument("data", "NDArray-or-Symbol", "Input ndarray")
.add_arguments(NumpyRollaxisParam::__FIELDS__());

NNVM_REGISTER_OP(_npi_rollaxis_backward)
.set_num_inputs(1)
.set_num_outputs(1)
.set_attr_parser(ParamParser<NumpyRollaxisParam>)
.set_attr<nnvm::TIsBackward>("TIsBackward", true)
.set_attr<FCompute>("FCompute<cpu>", NumpyRollaxisBackward<cpu>);

template<>
void NumpyFlipForwardImpl<cpu>(const OpContext& ctx,
const std::vector<TBlob>& inputs,
Expand Down
6 changes: 6 additions & 0 deletions src/operator/numpy/np_matrix_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,12 @@ NNVM_REGISTER_OP(_backward_npi_flip)
NNVM_REGISTER_OP(_np_moveaxis)
.set_attr<FCompute>("FCompute<gpu>", NumpyMoveaxisCompute<gpu>);

NNVM_REGISTER_OP(_npi_rollaxis)
.set_attr<FCompute>("FCompute<gpu>", NumpyRollaxisCompute<gpu>);

NNVM_REGISTER_OP(_npi_rollaxis_backward)
.set_attr<FCompute>("FCompute<gpu>", NumpyRollaxisBackward<gpu>);

NNVM_REGISTER_OP(_npi_rot90)
.set_attr<FCompute>("FCompute<gpu>", NumpyRot90Compute<gpu>);

Expand Down
65 changes: 0 additions & 65 deletions src/operator/numpy/np_rollaixs_op.cc

This file was deleted.

117 changes: 0 additions & 117 deletions src/operator/numpy/np_rollaxis_op-inl.h

This file was deleted.

13 changes: 0 additions & 13 deletions src/operator/numpy/np_rollaxis_op.cu

This file was deleted.

0 comments on commit ba318d5

Please sign in to comment.