This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[MXNET-382] Shape and Size Operator #10889

Merged: 31 commits (Jun 30, 2018)
Commits
4635957 - Shape Operator (May 10, 2018)
38c15d9 - cuda (May 10, 2018)
6540678 - size op (May 11, 2018)
01d1b95 - lint issues (May 11, 2018)
b513dbe - docs example (May 11, 2018)
125c3cd - add docs, change op name to avoid conflict, add convenience confluent (May 11, 2018)
ac96ef1 - change name to _nd (May 11, 2018)
93ffddc - fix test cases, add new kernel (May 15, 2018)
3d578d3 - test name fix. (May 15, 2018)
ef43d2f - Merge branch 'master' of https://github.com/apache/incubator-mxnet in… (May 15, 2018)
1b7ba47 - solve gpu memory problem for size and shape (Jun 11, 2018)
f8cc278 - Merge pull request #3 from haojin2/shape_op (anirudhacharya, Jun 11, 2018)
eb74750 - Merge branch 'master' of https://github.com/apache/incubator-mxnet in… (Jun 12, 2018)
08346da - get rid of FIgnoreInputs attr of shape_nd (Jun 12, 2018)
d2f7999 - Merge pull request #4 from haojin2/shape_op (anirudhacharya, Jun 12, 2018)
93fc294 - Merge branch 'shape' of https://github.com/anirudhacharya/incubator-m… (Jun 13, 2018)
3f84b1a - op name change (Jun 13, 2018)
3d8f560 - Merge branch 'master' of https://github.com/apache/incubator-mxnet in… (Jun 13, 2018)
2074b46 - Merge branch 'master' of https://github.com/apache/incubator-mxnet in… (Jun 18, 2018)
4b1164e - fix (Jun 22, 2018)
ee97196 - Merge branch 'master' of https://github.com/apache/incubator-mxnet into (Jun 25, 2018)
10cb562 - Merge branch 'master' of https://github.com/apache/incubator-mxnet in… (Jun 26, 2018)
cbec1e5 - retrigger CI (Jun 26, 2018)
039e6d4 - retrigger CI (Jun 26, 2018)
ee110fd - retrigger CI (Jun 26, 2018)
460c315 - Merge branch 'master' of https://github.com/apache/incubator-mxnet in… (Jun 26, 2018)
67cbbfe - trigger CI (Jun 26, 2018)
804dfb4 - fix comments (Jun 28, 2018)
f17f9c8 - cpplint (Jun 28, 2018)
c188404 - nit (Jun 29, 2018)
de64b97 - trigger CI (Jun 29, 2018)
4 changes: 4 additions & 0 deletions docs/api/python/ndarray/ndarray.md
@@ -124,6 +124,8 @@ The `ndarray` package provides several classes:
:nosignatures:

NDArray.T
NDArray.shape_array
NDArray.size_array
NDArray.reshape
NDArray.reshape_like
NDArray.flatten
@@ -375,6 +377,8 @@ The `ndarray` package provides several classes:
:nosignatures:

cast
shape_array
size_array
reshape
reshape_like
flatten
4 changes: 4 additions & 0 deletions docs/api/python/symbol/symbol.md
@@ -191,6 +191,8 @@ Composite multiple symbols into a new one by an operator.
:nosignatures:

Symbol.astype
Symbol.shape_array
Symbol.size_array
Symbol.reshape
Symbol.reshape_like
Symbol.flatten
@@ -373,6 +375,8 @@ Composite multiple symbols into a new one by an operator.
:nosignatures:

cast
shape_array
size_array
reshape
reshape_like
flatten
16 changes: 16 additions & 0 deletions python/mxnet/ndarray/ndarray.py
@@ -1254,6 +1254,22 @@ def flatten(self, *args, **kwargs):
"""
return op.flatten(self, *args, **kwargs)

def shape_array(self, *args, **kwargs):
"""Convenience fluent method for :py:func:`shape_array`.

The arguments are the same as for :py:func:`shape_array`, with
this array as data.
"""
return op.shape_array(self, *args, **kwargs)

def size_array(self, *args, **kwargs):
"""Convenience fluent method for :py:func:`size_array`.

The arguments are the same as for :py:func:`size_array`, with
this array as data.
"""
return op.size_array(self, *args, **kwargs)

def expand_dims(self, *args, **kwargs):
"""Convenience fluent method for :py:func:`expand_dims`.

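A minimal usage sketch of the two fluent methods added in this file (not part of the diff; it assumes a build that includes this change, and the expected values follow the operator docs in elemwise_unary_op_basic.cc below)::

    import mxnet as mx

    x = mx.nd.ones((2, 4))

    # Fluent calls forward to the registered operators; both return int64 NDArrays.
    print(x.shape_array().asnumpy())  # [2 4]
    print(x.size_array().asnumpy())   # [8]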
16 changes: 16 additions & 0 deletions python/mxnet/symbol/symbol.py
@@ -1982,6 +1982,22 @@ def flatten(self, *args, **kwargs):
"""
return op.flatten(self, *args, **kwargs)

def shape_array(self, *args, **kwargs):
"""Convenience fluent method for :py:func:`shape_array`.

The arguments are the same as for :py:func:`shape_array`, with
this array as data.
"""
return op.shape_array(self, *args, **kwargs)

def size_array(self, *args, **kwargs):
"""Convenience fluent method for :py:func:`size_array`.

The arguments are the same as for :py:func:`size_array`, with
this array as data.
"""
return op.size_array(self, *args, **kwargs)

def expand_dims(self, *args, **kwargs):
"""Convenience fluent method for :py:func:`expand_dims`.

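The symbolic counterparts behave the same way inside a graph. A minimal sketch (not part of the diff; mx.sym.Group and Symbol.eval are existing MXNet APIs)::

    import mxnet as mx

    data = mx.sym.var('data')
    shp = mx.sym.shape_array(data)  # 1-D int64 symbol holding the runtime shape
    sz = mx.sym.size_array(data)    # 1-D int64 symbol holding the element count

    outs = mx.sym.Group([shp, sz]).eval(ctx=mx.cpu(), data=mx.nd.ones((2, 4)))
    print(outs[0].asnumpy())  # [2 4]
    print(outs[1].asnumpy())  # [8]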
6 changes: 4 additions & 2 deletions python/mxnet/test_utils.py
@@ -1251,13 +1251,15 @@ def check_consistency(sym, ctx_list, scale=1.0, grad_req='write',
np.dtype(np.float32): 1e-3,
np.dtype(np.float64): 1e-5,
np.dtype(np.uint8): 0,
np.dtype(np.int32): 0}
np.dtype(np.int32): 0,
np.dtype(np.int64): 0}
elif isinstance(tol, numbers.Number):
tol = {np.dtype(np.float16): tol,
np.dtype(np.float32): tol,
np.dtype(np.float64): tol,
np.dtype(np.uint8): tol,
np.dtype(np.int32): tol}
np.dtype(np.int32): tol,
np.dtype(np.int64): tol}

assert len(ctx_list) > 1
if isinstance(sym, Symbol):
92 changes: 92 additions & 0 deletions src/operator/tensor/elemwise_unary_op_basic.cc
@@ -398,6 +398,98 @@ NNVM_REGISTER_OP(reshape_like)
.add_argument("lhs", "NDArray-or-Symbol", "First input.")
.add_argument("rhs", "NDArray-or-Symbol", "Second input.");

void ShapeComputeCPU(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
CHECK_EQ(inputs.size(), 1U);
CHECK_EQ(outputs.size(), 1U);
CHECK_EQ(req.size(), 1U);
const TBlob& in_data = inputs[0];
const TBlob& out_data = outputs[0];
memcpy(out_data.dptr_, in_data.shape_.data(), in_data.ndim() * sizeof(int64_t));
}

NNVM_REGISTER_OP(shape_array)
.describe(R"code(Returns a 1D int64 array containing the shape of data.

Example::

shape_array([[1,2,3,4], [5,6,7,8]]) = [2,4]

)code" ADD_FILELINE)
.set_num_inputs(1)
.set_num_outputs(1)
.set_attr<FCompute>("FCompute<cpu>", ShapeComputeCPU)
.set_attr<nnvm::FInferShape>("FInferShape",
[](const nnvm::NodeAttrs& attrs,
std::vector<TShape> *in_attrs,
std::vector<TShape> *out_attrs) {
CHECK_EQ(in_attrs->size(), 1U);
CHECK_EQ(out_attrs->size(), 1U);
TShape target_shape(1);
target_shape[0] = in_attrs->at(0).ndim();
SHAPE_ASSIGN_CHECK(*out_attrs, 0, target_shape);
return !shape_is_none(out_attrs->at(0));
})
.set_attr<nnvm::FInferType>("FInferType",
[](const nnvm::NodeAttrs& attrs,
std::vector<int>* in_attrs,
std::vector<int>* out_attrs) {
CHECK_EQ(in_attrs->size(), 1U);
CHECK_EQ(out_attrs->size(), 1U);
TYPE_ASSIGN_CHECK(*out_attrs, 0, mshadow::kInt64);
return out_attrs->at(0) != -1;
})
.add_argument("data", "NDArray-or-Symbol", "Input Array.");

void SizeComputeCPU(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
using namespace mshadow;
using namespace mxnet_op;
CHECK_EQ(inputs.size(), 1U);
CHECK_EQ(outputs.size(), 1U);
CHECK_EQ(req.size(), 1U);
const TBlob& in_data = inputs[0];
const TBlob& out_data = outputs[0];
// The output is declared as int64, so stage the value in a 64-bit variable before copying.
const int64_t size_var = static_cast<int64_t>(in_data.Size());
memcpy(out_data.dptr_, &size_var, sizeof(int64_t));
}

NNVM_REGISTER_OP(size_array)
.describe(R"code(Returns a 1D int64 array containing the size of data.

Example::

size_array([[1,2,3,4], [5,6,7,8]]) = [8]

)code" ADD_FILELINE)
.set_num_inputs(1)
.set_num_outputs(1)
.set_attr<FCompute>("FCompute<cpu>", SizeComputeCPU)
.set_attr<nnvm::FInferShape>("FInferShape",
[](const nnvm::NodeAttrs& attrs,
std::vector<TShape> *in_attrs,
std::vector<TShape> *out_attrs) {
CHECK_EQ(in_attrs->size(), 1U);
CHECK_EQ(out_attrs->size(), 1U);
SHAPE_ASSIGN_CHECK(*out_attrs, 0, 1U);
return !shape_is_none(out_attrs->at(0));
})
.set_attr<nnvm::FInferType>("FInferType",
[](const nnvm::NodeAttrs& attrs,
std::vector<int>* in_attrs,
std::vector<int>* out_attrs) {
CHECK_EQ(in_attrs->size(), 1U);
CHECK_EQ(out_attrs->size(), 1U);
TYPE_ASSIGN_CHECK(*out_attrs, 0, mshadow::kInt64);
return out_attrs->at(0) != -1;
})
.add_argument("data", "NDArray-or-Symbol", "Input Array.");

DMLC_REGISTER_PARAMETER(CastParam);
NNVM_REGISTER_OP(Cast)
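The FInferShape/FInferType attributes registered above pin the output to a 1-D int64 tensor: length ndim(data) for shape_array and length 1 for size_array, independent of the input dtype. A small inference sketch (not part of the diff; assumes this change is built in)::

    import numpy as np
    import mxnet as mx

    s = mx.sym.shape_array(mx.sym.var('data'))
    _, out_shapes, _ = s.infer_shape(data=(2, 3, 4))
    print(out_shapes)  # [(3,)] -- one slot per input dimension

    _, out_types, _ = s.infer_type(data=np.float32)
    print(out_types)   # [<class 'numpy.int64'>], regardless of the input dtype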
46 changes: 46 additions & 0 deletions src/operator/tensor/elemwise_unary_op_basic.cu
@@ -77,6 +77,52 @@ NNVM_REGISTER_OP(_identity_with_attr_like_rhs)
NNVM_REGISTER_OP(reshape_like)
.set_attr<FCompute>("FCompute<gpu>", UnaryOp::IdentityCompute<gpu>);

void ShapeComputeGPU(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
using namespace mshadow;
CHECK_EQ(inputs.size(), 1U);
CHECK_EQ(outputs.size(), 1U);
CHECK_EQ(req.size(), 1U);
const TBlob& in_data = inputs[0];
const TBlob& out_data = outputs[0];
mshadow::Stream<gpu> *s = ctx.get_stream<gpu>();
cudaMemcpyAsync(out_data.dptr_,
in_data.shape_.data(),
in_data.ndim() * sizeof(int64_t),
cudaMemcpyHostToDevice,
mshadow::Stream<gpu>::GetStream(s));
}

NNVM_REGISTER_OP(shape_array)
.set_attr<FCompute>("FCompute<gpu>", ShapeComputeGPU);

void SizeComputeGPU(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
using namespace mshadow;
using namespace mxnet_op;
CHECK_EQ(inputs.size(), 1U);
CHECK_EQ(outputs.size(), 1U);
CHECK_EQ(req.size(), 1U);
const TBlob& in_data = inputs[0];
const TBlob& out_data = outputs[0];
mshadow::Stream<gpu> *s = ctx.get_stream<gpu>();
// Stage the element count as int64 on the host; the output tensor's dtype is int64.
const int64_t size_var = static_cast<int64_t>(in_data.Size());
cudaMemcpyAsync(out_data.dptr_,
&size_var,
1U * sizeof(int64_t),
cudaMemcpyHostToDevice,
mshadow::Stream<gpu>::GetStream(s));
}

NNVM_REGISTER_OP(size_array)
.set_attr<FCompute>("FCompute<gpu>", SizeComputeGPU);

NNVM_REGISTER_OP(Cast)
.set_attr<FCompute>("FCompute<gpu>", CastCompute<gpu>);

4 changes: 2 additions & 2 deletions tests/python/unittest/test_ndarray.py
@@ -931,8 +931,8 @@ def check_fluent_regular(func, kwargs, shape=(5, 17, 1), equal_nan=False):
assert almost_equal(regular.asnumpy(), fluent.asnumpy(), equal_nan=equal_nan)

for func in ['flatten', 'norm', 'round', 'rint', 'fix', 'floor', 'ceil', 'trunc', 'zeros_like',
'ones_like', 'abs', 'sign', 'sin', 'cos', 'degrees', 'radians',
'exp', 'expm1', 'square', 'reciprocal', 'argmax_channel']:
'ones_like', 'abs', 'sign', 'sin', 'cos', 'degrees', 'radians', 'exp', 'expm1',
'square', 'reciprocal', 'argmax_channel', 'shape_array', 'size_array']:
check_fluent_regular(func, {})

for func in ['arccosh', 'arcsin', 'arccos', 'arctan', 'tan', 'sinh', 'cosh', 'tanh',
18 changes: 18 additions & 0 deletions tests/python/unittest/test_operator.py
@@ -820,6 +820,24 @@ def fsigmoid(a):
check_symbolic_forward(y, [xa], [ya])
check_symbolic_backward(y, [xa], [np.ones(shape)], [ya * (1 - ya)])

@with_seed()
def test_shape_array():
    for i in range(1, 6):
        shape = rand_shape_nd(i)
        x = np.random.ranf(shape)
        y = mx.nd.shape_array(mx.nd.array(x))
        expected_y = np.array(np.shape(x))
        assert same(y.asnumpy(), expected_y)

@with_seed()
def test_size_array():
    for i in range(1, 6):
        shape = rand_shape_nd(i)
        x = np.random.ranf(shape)
        y = mx.nd.size_array(mx.nd.array(x))
        expected_y = np.array([np.size(x)])
        assert same(y.asnumpy(), expected_y)

@with_seed()
def test_hard_sigmoid():
def fhardsigmoid(a, alpha=0.2, beta=0.5):
4 changes: 2 additions & 2 deletions tests/python/unittest/test_symbol.py
@@ -190,8 +190,8 @@ def check_fluent_regular(func, kwargs, shape=(5, 17, 1), equal_nan=False):
equal_nan=equal_nan)

for func in ['flatten', 'norm', 'round', 'rint', 'fix', 'floor', 'ceil', 'trunc', 'zeros_like',
'ones_like', 'abs', 'sign', 'sin', 'cos', 'degrees', 'radians',
'exp', 'expm1', 'square', 'reciprocal', 'argmax_channel']:
'ones_like', 'abs', 'sign', 'sin', 'cos', 'degrees', 'radians', 'exp', 'expm1',
'square', 'reciprocal', 'argmax_channel', 'shape_array', 'size_array']:
check_fluent_regular(func, {})

for func in ['arccosh', 'arcsin', 'arccos', 'arctan', 'tan', 'sinh', 'cosh', 'tanh',