diff --git a/src/common/cuda_utils.h b/src/common/cuda_utils.h
index a1c37a9478a7..9d3388b23510 100644
--- a/src/common/cuda_utils.h
+++ b/src/common/cuda_utils.h
@@ -479,6 +479,11 @@ static inline __device__ void atomicAdd(mshadow::half::half_t *address,
   } while (assumed != old);
 }

+// Overload atomicAdd to work for signed int64 on all architectures
+static inline __device__ void atomicAdd(int64_t *address, int64_t val) {
+  atomicAdd(reinterpret_cast<unsigned long long*>(address), static_cast<unsigned long long>(val));  // NOLINT
+}
+
 template <typename DType>
 __device__ inline DType ldg(const DType* address) {
 #if __CUDA_ARCH__ >= 350
diff --git a/src/operator/mxnet_op.h b/src/operator/mxnet_op.h
index 15ad59f5528b..081e40a62103 100644
--- a/src/operator/mxnet_op.h
+++ b/src/operator/mxnet_op.h
@@ -132,6 +132,50 @@ inline int get_num_threads(const int N) {
     LOG(FATAL) << "ndim=" << NDim << "too large ";  \
   }

+#define MXNET_NO_INT8_TYPE_SWITCH(type, DType, ...)  \
+  switch (type) {                                    \
+  case mshadow::kFloat32:                            \
+    {                                                \
+      typedef float DType;                           \
+      {__VA_ARGS__}                                  \
+    }                                                \
+    break;                                           \
+  case mshadow::kFloat64:                            \
+    {                                                \
+      typedef double DType;                          \
+      {__VA_ARGS__}                                  \
+    }                                                \
+    break;                                           \
+  case mshadow::kFloat16:                            \
+    {                                                \
+      typedef mshadow::half::half_t DType;           \
+      {__VA_ARGS__}                                  \
+    }                                                \
+    break;                                           \
+  case mshadow::kUint8:                              \
+    LOG(FATAL) << "This operation does not "         \
+                  "support int8 or uint8";           \
+    break;                                           \
+  case mshadow::kInt8:                               \
+    LOG(FATAL) << "This operation does not "         \
+                  "support int8 or uint8";           \
+    break;                                           \
+  case mshadow::kInt32:                              \
+    {                                                \
+      typedef int32_t DType;                         \
+      {__VA_ARGS__}                                  \
+    }                                                \
+    break;                                           \
+  case mshadow::kInt64:                              \
+    {                                                \
+      typedef int64_t DType;                         \
+      {__VA_ARGS__}                                  \
+    }                                                \
+    break;                                           \
+  default:                                           \
+    LOG(FATAL) << "Unknown type enum " << type;      \
+  }
+
 /*!
  * \brief assign the val to out according
diff --git a/src/operator/tensor/indexing_op.cc b/src/operator/tensor/indexing_op.cc
index 735da31b8b41..10905b538f18 100644
--- a/src/operator/tensor/indexing_op.cc
+++ b/src/operator/tensor/indexing_op.cc
@@ -137,6 +137,46 @@ inline void SparseEmbeddingOpBackwardRspImpl(const OpContext& ctx,
 }

+template<typename DType, typename IType>
+inline typename std::enable_if<(!std::is_same<DType, mshadow::half::half_t>::value), void>::type
+GatherNDBackwardImpl(int N, int M, int K,
+                     const mshadow::Shape<10> strides,
+                     DType* out,
+                     const DType* data,
+                     const IType* indices,
+                     mshadow::Stream<cpu> *s) {
+#pragma omp parallel for
+  for (int i = 0; i < N; i++) {
+    int offset = 0;
+    for (int j = 0; j < M; ++j) {
+      offset += strides[j] * static_cast<int>(indices[j*N + i]);
+    }
+    for (int j = 0; j < K; ++j) {
+#pragma omp atomic
+      out[offset + j] += data[i * K + j];
+    }
+  }
+}
+
+template<typename DType, typename IType>
+inline typename std::enable_if<std::is_same<DType, mshadow::half::half_t>::value, void>::type
+GatherNDBackwardImpl(int N, int M, int K,
+                     const mshadow::Shape<10> strides,
+                     DType* out,
+                     const DType* data,
+                     const IType* indices,
+                     mshadow::Stream<cpu> *s) {
+  for (int i = 0; i < N; i++) {
+    int offset = 0;
+    for (int j = 0; j < M; ++j) {
+      offset += strides[j] * static_cast<int>(indices[j*N + i]);
+    }
+    for (int j = 0; j < K; ++j) {
+      out[offset + j] += data[i * K + j];
+    }
+  }
+}
+
 DMLC_REGISTER_PARAMETER(EmbeddingParam);
 DMLC_REGISTER_PARAMETER(TakeParam);
 DMLC_REGISTER_PARAMETER(OneHotParam);
@@ -443,8 +483,7 @@ Examples::
 NNVM_REGISTER_OP(gather_nd)
 .describe(R"code(Gather elements or slices from `data` and store to a tensor whose
-shape is defined by `indices`. `gather_nd` and `scatter_nd` are inverse functions
-to each other.
+shape is defined by `indices`.

 Given `data` with shape `(X_0, X_1, ..., X_{N-1})` and indices with shape
 `(M, Y_0, ..., Y_{K-1})`, the output will have shape `(Y_0, ..., Y_{K-1}, X_M, ..., X_{N-1})`,
@@ -476,13 +515,14 @@ Examples::
 .set_attr<nnvm::FGradient>("FGradient",
   [](const nnvm::NodePtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
     auto p = nnvm::Node::Create();
-    p->attrs.op = nnvm::Op::Get("scatter_nd");
+    p->attrs.op = nnvm::Op::Get("_backward_gather_nd");
     p->attrs.name = n->attrs.name + "_backward";
     p->inputs.push_back(ograds[0]);
     p->inputs.push_back(n->inputs[1]);
     p->control_deps.emplace_back(n);
     auto zero = MakeNode("zeros_like", n->attrs.name + "_backward_indices",
                          {n->inputs[1]}, nullptr, &n);
+
     std::vector<nnvm::NodeEntry> ret;
     ret.emplace_back(nnvm::NodeEntry{p, 0, 0});
     ret.emplace_back(nnvm::NodeEntry{zero, 0, 0});
@@ -492,10 +532,8 @@ Examples::
 .add_argument("data", "NDArray-or-Symbol", "data")
 .add_argument("indices", "NDArray-or-Symbol", "indices");

-
 NNVM_REGISTER_OP(scatter_nd)
 .describe(R"code(Scatters data into a new tensor according to indices.
-`gather_nd` and `scatter_nd` are inverse functions to each other.

 Given `data` with shape `(Y_0, ..., Y_{K-1}, X_M, ..., X_{N-1})` and indices with shape
 `(M, Y_0, ..., Y_{K-1})`, the output will have shape `(X_0, X_1, ..., X_{N-1})`,
@@ -510,6 +548,12 @@ The elements in output is defined as follows::

 all other entries in output are 0.

+.. warning::
+
+    If the indices have duplicates, the result will be non-deterministic and
+    the gradient of `scatter_nd` will not be correct!
+
+
 Examples::

   data = [2, 3, 0]
@@ -548,11 +592,73 @@ Examples::
 .add_argument("indices", "NDArray-or-Symbol", "indices")
 .add_arguments(ScatterNDParam::__FIELDS__());

+NNVM_REGISTER_OP(_backward_gather_nd)
+.describe(R"code(Accumulates data according to indices and gets the result. It's the backward of
+`gather_nd`.
+
+Given `data` with shape `(Y_0, ..., Y_{K-1}, X_M, ..., X_{N-1})` and indices with shape
+`(M, Y_0, ..., Y_{K-1})`, the output will have shape `(X_0, X_1, ..., X_{N-1})`,
+where `M <= N`. If `M == N`, data shape should simply be `(Y_0, ..., Y_{K-1})`.
+
+The elements in output are defined as follows::
+
+  output[indices[0, y_0, ..., y_{K-1}],
+         ...,
+         indices[M-1, y_0, ..., y_{K-1}],
+         x_M, ..., x_{N-1}] += data[y_0, ..., y_{K-1}, x_M, ..., x_{N-1}]
+
+all other entries in output are 0 or the original value if AddTo is triggered.
+
+Examples::
+
+  data = [2, 3, 0]
+  indices = [[1, 1, 0], [0, 1, 0]]
+  shape = (2, 2)
+  _backward_gather_nd(data, indices, shape) = [[0, 0], [2, 3]]  # Same as scatter_nd
+
+  # The difference between scatter_nd and _backward_gather_nd is that the latter accumulates
+  # the values that point to the same index.
+
+  data = [2, 3, 0]
+  indices = [[1, 1, 0], [1, 1, 0]]
+  shape = (2, 2)
+  _backward_gather_nd(data, indices, shape) = [[0, 0], [0, 5]]
+
+)code")
+.set_num_outputs(1)
+.set_num_inputs(2)
+.set_attr_parser(ParamParser<ScatterNDParam>)
+.set_attr<nnvm::FListInputNames>("FListInputNames",
+  [](const NodeAttrs& attrs) {
+    return std::vector<std::string>{"data", "indices"};
+  })
+.set_attr<nnvm::FInferShape>("FInferShape", ScatterNDShape)
+.set_attr<nnvm::FInferType>("FInferType", ScatterNDType)
+.set_attr<FCompute>("FCompute", GatherNDBackward<cpu>)
+.set_attr<nnvm::FGradient>("FGradient",
+  [](const nnvm::NodePtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
+    auto p = nnvm::Node::Create();
+    p->attrs.op = nnvm::Op::Get("gather_nd");
+    p->attrs.name = n->attrs.name + "_backward";
+    p->inputs.push_back(ograds[0]);
+    p->inputs.push_back(n->inputs[1]);
+    p->control_deps.emplace_back(n);
+    auto zero = MakeNode("zeros_like", n->attrs.name + "_backward_indices",
+                         {n->inputs[1]}, nullptr, &n);
+    std::vector<nnvm::NodeEntry> ret;
+    ret.emplace_back(nnvm::NodeEntry{p, 0, 0});
+    ret.emplace_back(nnvm::NodeEntry{zero, 0, 0});
+    return ret;
+  })
+.set_attr<nnvm::TIsBackward>("TIsBackward", true)
+.add_argument("data", "NDArray-or-Symbol", "data")
+.add_argument("indices", "NDArray-or-Symbol", "indices")
+.add_arguments(ScatterNDParam::__FIELDS__());
+
 NNVM_REGISTER_OP(_scatter_set_nd)
 .describe(R"code(This operator has the same functionality as scatter_nd
 except that it does not reset the elements not indexed by the input
 index `NDArray` in the input data `NDArray`.
-
 .. note:: This operator is for internal use only.

 Examples::
diff --git a/src/operator/tensor/indexing_op.cu b/src/operator/tensor/indexing_op.cu
index 4021f2b3a217..762d8fd64c2b 100644
--- a/src/operator/tensor/indexing_op.cu
+++ b/src/operator/tensor/indexing_op.cu
@@ -179,6 +179,32 @@ inline void SparseEmbeddingOpBackwardRspImpl(const OpContext& ctx,
   });
 }

+struct backward_gather_nd_gpu {
+  template<typename DType, typename IType>
+  MSHADOW_XINLINE static void Map(int i, int N, int M, int K,
+                                  const mshadow::Shape<10> strides,
+                                  DType* out, const DType* data,
+                                  const IType* indices) {
+    int offset = 0;
+    for (int j = 0; j < M; ++j) {
+      offset += strides[j] * static_cast<int>(indices[j*N + i]);
+    }
+    for (int j = 0; j < K; ++j) {
+      atomicAdd(out + (offset + j), data[i * K + j]);
+    }
+  }
+};
+
+template<typename DType, typename IType>
+inline void GatherNDBackwardImpl(int N, int M, int K,
+                                 const mshadow::Shape<10> strides,
+                                 DType* out,
+                                 const DType* data,
+                                 const IType* indices,
+                                 mshadow::Stream<gpu> *s) {
+  mxnet_op::Kernel<backward_gather_nd_gpu, gpu>::Launch(s, N, N, M, K, strides, out, data, indices);
+}
+
 NNVM_REGISTER_OP(Embedding)
 .set_attr<FCompute>("FCompute", EmbeddingOpForward<gpu>);

@@ -209,6 +235,9 @@ NNVM_REGISTER_OP(gather_nd)
 NNVM_REGISTER_OP(scatter_nd)
 .set_attr<FCompute>("FCompute", ScatterNDForward<gpu>);

+NNVM_REGISTER_OP(_backward_gather_nd)
+.set_attr<FCompute>("FCompute", GatherNDBackward<gpu>);
+
 NNVM_REGISTER_OP(_scatter_set_nd)
 .set_attr<FCompute>("FCompute", ScatterSetNDForward<gpu>);
 }  // namespace op
diff --git a/src/operator/tensor/indexing_op.h b/src/operator/tensor/indexing_op.h
index 4043e76cfdae..7323f81c09ac 100644
--- a/src/operator/tensor/indexing_op.h
+++ b/src/operator/tensor/indexing_op.h
@@ -1131,10 +1131,10 @@ void ScatterNDForward(const nnvm::NodeAttrs& attrs,
   int K = oshape.ProdShape(M, oshape.ndim());
   mshadow::Shape<10> strides;
   for (int i = M-1, stride = K; i >= 0; stride *= oshape[i], --i) strides[i] = stride;
+  if (kWriteTo == req[0]) {
+    Fill(s, outputs[0], req[0], 0);
+  }
   MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, DType, {  // output data type switch
-    if (kWriteTo == req[0]) {
-      Fill(s, outputs[0], req[0], 0);
-    }
     MSHADOW_TYPE_SWITCH(inputs[1].type_flag_, IType, {  // indices data type switch
       mxnet_op::Kernel<scatter_nd, xpu>::Launch(
         s, N, req[0], N, M, K, strides, outputs[0].dptr<DType>(),
@@ -1143,6 +1143,64 @@ void ScatterNDForward(const nnvm::NodeAttrs& attrs,
   });
 }

+template<typename DType, typename IType>
+inline typename std::enable_if<(!std::is_same<DType, mshadow::half::half_t>::value), void>::type
+GatherNDBackwardImpl(int N, int M, int K,
+                     const mshadow::Shape<10> strides,
+                     DType* out,
+                     const DType* data,
+                     const IType* indices,
+                     mshadow::Stream<cpu> *s);
+
+template<typename DType, typename IType>
+inline typename std::enable_if<std::is_same<DType, mshadow::half::half_t>::value, void>::type
+GatherNDBackwardImpl(int N, int M, int K,
+                     const mshadow::Shape<10> strides,
+                     DType* out,
+                     const DType* data,
+                     const IType* indices,
+                     mshadow::Stream<cpu> *s);
+
+template<typename DType, typename IType>
+inline void GatherNDBackwardImpl(int N, int M, int K,
+                                 const mshadow::Shape<10> strides,
+                                 DType* out,
+                                 const DType* data,
+                                 const IType* indices,
+                                 mshadow::Stream<gpu> *s);
+
+template<typename xpu>
+void GatherNDBackward(const nnvm::NodeAttrs& attrs,
+                      const OpContext& ctx,
+                      const std::vector<TBlob>& inputs,
+                      const std::vector<OpReqType>& req,
+                      const std::vector<TBlob>& outputs) {
+  using namespace mshadow;
+  CHECK_EQ(inputs.size(), 2U);
+  CHECK_EQ(outputs.size(), 1U);
+  if (req[0] == kNullOp) return;
+  mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
+  const TShape& oshape = outputs[0].shape_;
+  const TShape& ishape = inputs[1].shape_;
+  int M = ishape[0];
+  int N = ishape.Size() / M;
+  int K = oshape.ProdShape(M, oshape.ndim());
+  mshadow::Shape<10> strides;
+  for (int i = M-1, stride = K; i >= 0; stride *= oshape[i], --i) strides[i] = stride;
+  if (kWriteTo == req[0]) {
+    Fill(s, outputs[0], req[0], 0);
+  }
+  MXNET_NO_INT8_TYPE_SWITCH(inputs[0].type_flag_, DType, {  // output data type switch
+    MSHADOW_TYPE_SWITCH(inputs[1].type_flag_, IType, {  // indices data type switch
+      GatherNDBackwardImpl(N, M, K, strides,
+                           outputs[0].dptr<DType>(),
+                           inputs[0].dptr<DType>(),
+                           inputs[1].dptr<IType>(),
+                           s);
+    });
+  });
+}
+
 /*!
  * This is for internal use only.
  * DO NOT call this function unless you have to.
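For reference, the accumulation that `_backward_gather_nd` performs (the same strides-based offset walk as the CPU and GPU kernels above) can be sketched in NumPy. This is an illustrative sketch only; `backward_gather_nd_ref` is a hypothetical helper name and is not part of this patch::

    import numpy as np

    def backward_gather_nd_ref(data, indices, shape):
        # data    : shape (Y_0, ..., Y_{K-1}, X_M, ..., X_{N-1})
        # indices : integer array of shape (M, Y_0, ..., Y_{K-1})
        # shape   : output shape (X_0, ..., X_{N-1})
        out = np.zeros(shape, dtype=data.dtype)
        M = indices.shape[0]
        idx = indices.reshape(M, -1)               # one column per output slice
        flat_data = data.reshape(idx.shape[1], -1)
        for i in range(idx.shape[1]):
            dst = tuple(idx[:, i])
            # '+=' accumulates when several columns point at the same index,
            # which is exactly where the old scatter_nd-based gradient was wrong.
            out[dst] += flat_data[i].reshape(out[dst].shape)
        return out

    # backward_gather_nd_ref(np.array([2, 3, 0]), np.array([[1, 1, 0], [1, 1, 0]]), (2, 2))
    # -> [[0, 0], [0, 5]], matching the _backward_gather_nd docstring example above
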
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index 3fbf98becc8a..56dc27c4938f 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -4391,21 +4391,32 @@ def check(data, idx):
         npdata = np.zeros_like(data.asnumpy())
         npdata[npidx] = y.asnumpy()
         assert (npdata == data.grad.asnumpy()).all()
-        assert (mx.nd.scatter_nd(y, idx, shape=data.shape).asnumpy() == data.grad.asnumpy()).all()
-
-    data = mx.nd.arange(360, dtype='int32').reshape((3,4,5,6))
-    idx = mx.nd.array([[1,1,2], [3, 3, 0], [3,2,1]], dtype='int32')
-
-    check(data, idx)
-
-    idx = mx.nd.array([[1,1,2], [3,3,0], [3,2,1], [5,2,4]], dtype='int32')
-
-    check(data, idx)
-
-    data = mx.nd.array([2, 3, 0])
-    idx = mx.nd.array([[1, 1, 0], [0, 1, 0]])
-
-    assert (mx.nd.scatter_nd(data, idx, shape=(2, 2)).asnumpy() == [[0, 0], [2, 3]]).all()
+        assert (mx.nd._internal._backward_gather_nd(y, idx, shape=data.shape).asnumpy() == data.grad.asnumpy()).all()
+    for dtype in ['int32', 'int64', 'float16', 'float32', 'float64']:
+        data = mx.nd.arange(360, dtype=dtype).reshape((3,4,5,6))
+        idx = mx.nd.array([[1,1,2], [3, 3, 0], [3,2,1]], dtype='int32')
+        check(data, idx)
+
+        idx = mx.nd.array([[1,1,2], [3,3,0], [3,2,1], [5,2,4]], dtype='int32')
+
+        check(data, idx)
+
+        data = mx.nd.array([2, 3, 0], dtype=dtype)
+        idx = mx.nd.array([[1, 1, 0], [0, 1, 0]], dtype='int32')
+        assert (mx.nd.scatter_nd(data, idx, shape=(2, 2)).asnumpy() == [[0, 0], [2, 3]]).all()
+
+        data = mx.nd.array([2, 3, 0], dtype=dtype)
+        idx = mx.nd.array([[1, 1, 0], [1, 1, 0]], dtype='int32')
+        assert (mx.nd._internal._backward_gather_nd(data, idx, shape=(2, 2)).asnumpy() == [[0, 0], [0, 5]]).all()
+        data_npy = np.random.randint(0, 10, (100,))
+        data = mx.nd.array(data_npy, dtype=dtype)
+        idx = mx.nd.zeros(shape=(1, 100), dtype='int32')
+        assert (mx.nd._internal._backward_gather_nd(data, idx, shape=(1,)).asscalar() == data_npy.sum())
+        if dtype == 'int64':
+            data = mx.nd.array([2123162361283621, -31231236374787,
+                                -112372937128970, -1378278798172378], dtype=dtype)
+            idx = mx.nd.array([[0, 0, 0, 0]], dtype='int32')
+            assert (mx.nd._internal._backward_gather_nd(data, idx, shape=(1,)).asscalar() == data.asnumpy().sum())

 def compare_forw_backw_unary_op(
         name, forward_mxnet_call, forward_numpy_call,
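With `_backward_gather_nd` registered as the gradient of `gather_nd`, duplicate entries in `indices` now accumulate gradient instead of overwriting each other. A minimal usage sketch, assuming a build with this patch (printed formatting abbreviated)::

    import mxnet as mx

    data = mx.nd.array([[0., 1.], [2., 3.]])
    data.attach_grad()
    # Both columns of `indices` address data[1, 0].
    indices = mx.nd.array([[1, 1], [0, 0]], dtype='int32')
    with mx.autograd.record():
        y = mx.nd.gather_nd(data, indices)   # y = [2., 2.]
    y.backward()
    print(data.grad)
    # [[0. 0.]
    #  [2. 0.]]  -- the element gathered twice receives gradient 2;
    #              the old scatter_nd-based gradient would have reported 1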