diff --git a/python/mxnet/ndarray/numpy/_op.py b/python/mxnet/ndarray/numpy/_op.py index 6da2c0641153..3538d5480c8f 100644 --- a/python/mxnet/ndarray/numpy/_op.py +++ b/python/mxnet/ndarray/numpy/_op.py @@ -6097,9 +6097,10 @@ def unravel_index(indices, shape, order='C'): # pylint: disable=redefined-outer- if order == 'C': if isinstance(indices, numeric_types): return _np.unravel_index(indices, shape) - return tuple(_npi.unravel_index_fallback(indices, shape=shape)) - else: - raise NotImplementedError('Do not support column-major (Fortran-style) order at this moment') + if isinstance(indices, NDArray): + return tuple(_api_internal.unravel_index(indices, shape)) + raise TypeError('Do not support type {} as indices.'.format(str(type(indices)))) + raise NotImplementedError('Do not support column-major (Fortran-style) order at this moment') def flatnonzero(a): diff --git a/src/api/operator/tensor/unravel.cc b/src/api/operator/tensor/unravel.cc new file mode 100644 index 000000000000..3c60d8ed4e41 --- /dev/null +++ b/src/api/operator/tensor/unravel.cc @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! 
+ * \file unravel.cc + * \brief Implementation of the API of functions in src/operator/tensor/ravel.cc + */ +#include <mxnet/api_registry.h> +#include <mxnet/runtime/packed_func.h> +#include "../utils.h" +#include "../../../operator/tensor/ravel.h" + +namespace mxnet { + +MXNET_REGISTER_API("_npi.unravel_index") + .set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { + using namespace runtime; + const nnvm::Op* op = Op::Get("_npi_unravel_index"); + nnvm::NodeAttrs attrs; + op::RavelParam param; + if (args[1].type_code() == kNull) { + param.shape = TShape(-1, 0); + } else if (args[1].type_code() == kDLInt) { + param.shape = TShape(1, args[1].operator int64_t()); + } else { + param.shape = TShape(args[1].operator ObjectRef()); + } + attrs.parsed = param; + attrs.op = op; + SetAttrDict<op::RavelParam>(&attrs); + NDArray* inputs[] = {args[0].operator mxnet::NDArray *()}; + int num_inputs = 1; + int num_outputs = 0; + auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr); + if (num_outputs == 1) { + *ret = ndoutputs[0]; + } else { + std::vector<NDArrayHandle> ndarray_handles; + ndarray_handles.reserve(num_outputs); + for (int i = 0; i < num_outputs; ++i) { + ndarray_handles.emplace_back(ndoutputs[i]); + } + *ret = ADT(0, ndarray_handles.begin(), ndarray_handles.end()); + } + }); + +} // namespace mxnet diff --git a/src/operator/tensor/ravel.cc b/src/operator/tensor/ravel.cc index 4b98887dabe8..2f471f8e8677 100644 --- a/src/operator/tensor/ravel.cc +++ b/src/operator/tensor/ravel.cc @@ -61,6 +61,7 @@ Examples:: NNVM_REGISTER_OP(_unravel_index) .add_alias("unravel_index") + .add_alias("_npi_unravel_index") .describe( R"code(Converts an array of flat indices into a batch of index arrays. The operator follows numpy conventions so a single multi index is given by a column of the output matrix. The leading dimension may be left unspecified by using -1 as placeholder. 
diff --git a/src/operator/tensor/ravel.h b/src/operator/tensor/ravel.h index d192b35060f2..0fd6069f94ad 100644 --- a/src/operator/tensor/ravel.h +++ b/src/operator/tensor/ravel.h @@ -25,6 +25,7 @@ #define MXNET_OPERATOR_TENSOR_RAVEL_H_ #include <mxnet/operator_util.h> +#include <sstream> #include <vector> #include <string> #include "../mshadow_op.h" @@ -42,6 +43,11 @@ struct RavelParam : public dmlc::Parameter<RavelParam> { .set_default(mxnet::TShape()) .describe("Shape of the array into which the multi-indices apply."); } + void SetAttrDict(std::unordered_map<std::string, std::string>* dict) { + std::ostringstream shape_s; + shape_s << shape; + (*dict)["shape"] = shape_s.str(); + } }; inline bool RavelOpShape(const nnvm::NodeAttrs& attrs, @@ -75,7 +81,7 @@ inline bool UnravelOpShape(const nnvm::NodeAttrs& attrs, CHECK_EQ(out_attrs->size(), 1); CHECK_GT(shape.ndim(), 0) << "Empty shape parameter for unravel operator."; const mxnet::TShape& in_shape = (*in_attrs)[0]; - if (in_shape.ndim() > 0) { + if (in_shape.ndim() >= 0) { mxnet::TShape out_shape(in_shape.ndim() + 1, -1); out_shape[0] = shape.ndim(); for (int i = 0; i < in_shape.ndim(); ++i) { diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py index 0db209c5774f..4d2588ac8cf1 100644 --- a/tests/python/unittest/test_numpy_op.py +++ b/tests/python/unittest/test_numpy_op.py @@ -10053,8 +10053,20 @@ def forward(self, x): @use_np -@pytest.mark.skip(reason='Test hangs. 
Tracked in #18144') -def test_np_unravel_index(): +@pytest.mark.parametrize('ishape', [ + 2, 5, + (), (1,), (4,), + (2, 2), (2, 4), (3, 5), + (2, 2, 2), (2, 3, 2), (2, 3, 4), +]) +@pytest.mark.parametrize('rshape', [ + 10, (15,), + (3, 4), (4, 5), + (2,3,4) +]) +@pytest.mark.parametrize('dtype', [np.uint8, np.int8, np.int32, np.int64]) +@pytest.mark.parametrize('hybridize', [True, False]) +def test_np_unravel_index(ishape, rshape, dtype, hybridize): class TestUnravel_index(HybridBlock): def __init__(self, shape, order='C') : super(TestUnravel_index, self).__init__() @@ -10064,44 +10076,33 @@ def __init__(self, shape, order='C') : def forward(self, a): return np.unravel_index(a, self._shape, self._order) - in_shapes = [ - 2, 5, - (), (1,), (4,), - (2, 2), (2, 4), (3, 5), - (2, 2, 2), (2, 3, 2), (2, 3, 4), - ] - unravel_shapes = [ - 10, (15,), - (3, 4), (4, 5), - (2,3,4) - ] - dtypes = [np.uint8, np.int8, np.int32, np.int64] - for hybridize, ishape, dtype, rshape in itertools.product([False, True], in_shapes, dtypes, unravel_shapes): - rtol = 1e-2 if dtype == np.float16 else 1e-3 - atol = 1e-4 if dtype == np.float16 else 1e-5 - test_unravel_index = TestUnravel_index(rshape) - if hybridize: - test_unravel_index.hybridize() - if type(ishape) == int and hybridize: - x = np.array([ishape], dtype=dtype) - np_out = onp.unravel_index(x.asnumpy(), rshape) - else: - x = np.random.uniform(0, 8, size=ishape).astype(dtype) - np_out = onp.unravel_index(x.asnumpy(), rshape) - mx_out = test_unravel_index(x) - assert len(mx_out) == len(np_out) - for elem_mx, elem_np in zip(mx_out, np_out): - assert elem_mx.asnumpy().shape == elem_np.shape - assert_almost_equal(elem_mx.asnumpy(), elem_np, rtol=rtol, atol=atol) - # no backward function for unravel_index operator - # Test imperative once again - mx_out = np.unravel_index(x, rshape) + rtol = 1e-2 if dtype == np.float16 else 1e-3 + atol = 1e-4 if dtype == np.float16 else 1e-5 + test_unravel_index = TestUnravel_index(rshape) + if 
hybridize: + test_unravel_index.hybridize() + if type(ishape) == int and hybridize: + x = np.array([ishape], dtype=dtype) np_out = onp.unravel_index(x.asnumpy(), rshape) - assert len(mx_out) == len(np_out) - for elem_mx, elem_np in zip(mx_out, np_out): - assert elem_mx.asnumpy().shape == elem_np.shape - assert_almost_equal(elem_mx.asnumpy(), elem_np, rtol=rtol, atol=atol) + else: + x = np.random.uniform(0, 8, size=ishape).astype(dtype) + np_out = onp.unravel_index(x.asnumpy(), rshape) + mx_out = test_unravel_index(x) + assert len(mx_out) == len(np_out) + for elem_mx, elem_np in zip(mx_out, np_out): + assert elem_mx.asnumpy().shape == elem_np.shape + assert_almost_equal(elem_mx.asnumpy(), elem_np, rtol=rtol, atol=atol) + # no backward function for unravel_index operator + + # Test imperative once again + mx_out = np.unravel_index(x, rshape) + np_out = onp.unravel_index(x.asnumpy(), rshape) + print(np_out) + assert len(mx_out) == len(np_out) + for elem_mx, elem_np in zip(mx_out, np_out): + assert elem_mx.asnumpy().shape == elem_np.shape + assert_almost_equal(elem_mx.asnumpy(), elem_np, rtol=rtol, atol=atol) @use_np