This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Fix the gradient of gather_nd #9200

Merged: 26 commits, merged Jan 4, 2018
Changes from 4 commits
5 changes: 5 additions & 0 deletions src/common/cuda_utils.h
@@ -479,6 +479,11 @@ static inline __device__ void atomicAdd(mshadow::half::half_t *address,
} while (assumed != old);
}

// Overload atomicAdd to work for signed int64 on all architectures
static inline __device__ void atomicAdd(int64_t *address, int64_t val) {
atomicAdd(reinterpret_cast<unsigned long long*>(address), static_cast<unsigned long long>(val)); // NOLINT
Contributor: Are you sure this works for negative values?

Member Author: It should be safe if CUDA uses two's complement to implement signed long long.
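For reference, the two's-complement argument can be checked with a minimal host-side sketch (illustration only, not part of the diff; the helper name is made up):

#include <cassert>
#include <cstdint>
#include <cstring>

// Mirrors what the overload above does: reinterpret the 64-bit pattern as
// unsigned, add modulo 2^64, and read the result back as signed. Under two's
// complement this produces the same bit pattern as signed addition, so
// negative values come out correctly.
int64_t add_via_unsigned(int64_t a, int64_t b) {
  uint64_t ua, ub;
  std::memcpy(&ua, &a, sizeof(ua));
  std::memcpy(&ub, &b, sizeof(ub));
  const uint64_t sum = ua + ub;  // well-defined wrap-around arithmetic
  int64_t out;
  std::memcpy(&out, &sum, sizeof(out));
  return out;
}

int main() {
  assert(add_via_unsigned(5, -7) == -2);
  assert(add_via_unsigned(-3, -4) == -7);
  return 0;
}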

}

template <typename DType>
__device__ inline DType ldg(const DType* address) {
#if __CUDA_ARCH__ >= 350
44 changes: 44 additions & 0 deletions src/operator/mxnet_op.h
@@ -132,6 +132,50 @@ inline int get_num_threads<cpu>(const int N) {
LOG(FATAL) << "ndim=" << NDim << "too large "; \
}

#define MXNET_NO_INT8_TYPE_SWITCH(type, DType, ...) \
switch (type) { \
case mshadow::kFloat32: \
{ \
typedef float DType; \
{__VA_ARGS__} \
} \
break; \
case mshadow::kFloat64: \
{ \
typedef double DType; \
{__VA_ARGS__} \
} \
break; \
case mshadow::kFloat16: \
{ \
typedef mshadow::half::half_t DType; \
{__VA_ARGS__} \
} \
break; \
case mshadow::kUint8: \
LOG(FATAL) << "This operation does not " \
"support int8 or uint8"; \
break; \
case mshadow::kInt8: \
LOG(FATAL) << "This operation does not " \
"support int8 or uint8"; \
break; \
case mshadow::kInt32: \
{ \
typedef int32_t DType; \
{__VA_ARGS__} \
} \
break; \
case mshadow::kInt64: \
{ \
typedef int64_t DType; \
{__VA_ARGS__} \
} \
break; \
default: \
LOG(FATAL) << "Unknown type enum " << type; \
}


/*!
* \brief assign the val to out according
92 changes: 87 additions & 5 deletions src/operator/tensor/indexing_op.cc
@@ -137,6 +137,24 @@ inline void SparseEmbeddingOpBackwardRspImpl<cpu>(const OpContext& ctx,
}


template<typename DType, typename IType>
inline void ScatterNDAccForwardImpl(int N, int M, int K,
const mshadow::Shape<10> strides,
DType* out,
const DType* data,
const IType* indices,
mshadow::Stream<cpu> *s) {
for (int i = 0; i < N; i++) {
Contributor: This is single-threaded. Can we use #pragma omp critical or #pragma omp atomic for the CPU kernel?

Member Author: Yes, we can use OpenMP. Let me have a try.

int offset = 0;
for (int j = 0; j < M; ++j) {
offset += strides[j] * static_cast<int>(indices[j*N + i]);
}
for (int j = 0; j < K; ++j) {
out[offset + j] += data[i * K + j];
Contributor: You can consolidate this with the GPU kernel by using #if __CUDA__ / #else in the header file, since this line is the only difference between the CPU and GPU kernels. Then, in the FCompute function, you can use Kernel::Launch for both CPU and GPU kernels. That would make the implementation less verbose.

Member Author: @reminisce I've specialized the implementation for half_t and now it passes the test.

}
}
}
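Following up on the OpenMP exchange in the review above, a minimal sketch of what a parallelized CPU loop could look like (hypothetical helper, not what the PR ships; assumes DType is a built-in arithmetic type, so half_t would still need separate handling):

template<typename DType, typename IType>
inline void ScatterNDAccForwardOmpSketch(int N, int M, int K,
                                         const mshadow::Shape<10> strides,
                                         DType* out,
                                         const DType* data,
                                         const IType* indices) {
  #pragma omp parallel for
  for (int i = 0; i < N; i++) {
    int offset = 0;
    for (int j = 0; j < M; ++j) {
      offset += strides[j] * static_cast<int>(indices[j * N + i]);
    }
    for (int j = 0; j < K; ++j) {
      // Different i can map to the same offset, so the update itself must be atomic.
      #pragma omp atomic
      out[offset + j] += data[i * K + j];
    }
  }
}

An omp atomic protects only the single update and is generally cheaper than an omp critical section, at the cost of serializing colliding writes.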

DMLC_REGISTER_PARAMETER(EmbeddingParam);
DMLC_REGISTER_PARAMETER(TakeParam);
DMLC_REGISTER_PARAMETER(OneHotParam);
@@ -443,7 +461,7 @@ Examples::

NNVM_REGISTER_OP(gather_nd)
.describe(R"code(Gather elements or slices from `data` and store to a tensor whose
shape is defined by `indices`. `gather_nd` and `scatter_nd` are inverse functions
shape is defined by `indices`. `gather_nd` and `scatter_nd_acc` are inverse functions
to each other.

Given `data` with shape `(X_0, X_1, ..., X_{N-1})` and indices with shape
@@ -476,13 +494,14 @@ Examples::
.set_attr<nnvm::FGradient>("FGradient",
[](const nnvm::NodePtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
auto p = nnvm::Node::Create();
p->attrs.op = nnvm::Op::Get("scatter_nd");
p->attrs.op = nnvm::Op::Get("scatter_nd_acc");
p->attrs.name = n->attrs.name + "_backward";
p->inputs.push_back(ograds[0]);
p->inputs.push_back(n->inputs[1]);
p->control_deps.emplace_back(n);
auto zero = MakeNode("zeros_like", n->attrs.name + "_backward_indices",
{n->inputs[1]}, nullptr, &n);

std::vector<nnvm::NodeEntry> ret;
ret.emplace_back(nnvm::NodeEntry{p, 0, 0});
ret.emplace_back(nnvm::NodeEntry{zero, 0, 0});
@@ -492,10 +511,8 @@ Examples::
.add_argument("data", "NDArray-or-Symbol", "data")
.add_argument("indices", "NDArray-or-Symbol", "indices");


NNVM_REGISTER_OP(scatter_nd)
.describe(R"code(Scatters data into a new tensor according to indices.
`gather_nd` and `scatter_nd` are inverse functions to each other.

Given `data` with shape `(Y_0, ..., Y_{K-1}, X_M, ..., X_{N-1})` and indices with shape
`(M, Y_0, ..., Y_{K-1})`, the output will have shape `(X_0, X_1, ..., X_{N-1})`,
Expand All @@ -510,6 +527,10 @@ The elements in output is defined as follows::

all other entries in output are 0.

WARNING!!! If the indices have duplicates, the result will be non-deterministic and
Contributor: This looks ugly. The standard warning format is:

.. Warning:: xxx

the gradient of `scatter_nd` will not be correct!!


Examples::

data = [2, 3, 0]
@@ -548,11 +569,72 @@
.add_argument("indices", "NDArray-or-Symbol", "indices")
.add_arguments(ScatterNDParam::__FIELDS__());

NNVM_REGISTER_OP(scatter_nd_acc)
.describe(R"code(Accumulates data according to indices and get the result.

Given `data` with shape `(Y_0, ..., Y_{K-1}, X_M, ..., X_{N-1})` and indices with shape
`(M, Y_0, ..., Y_{K-1})`, the output will have shape `(X_0, X_1, ..., X_{N-1})`,
where `M <= N`. If `M == N`, data shape should simply be `(Y_0, ..., Y_{K-1})`.

The elements in output is defined as follows::

output[indices[0, y_0, ..., y_{K-1}],
...,
indices[M-1, y_0, ..., y_{K-1}],
x_M, ..., x_{N-1}] += data[y_0, ..., y_{K-1}, x_M, ..., x_{N-1}]

all other entries in output are 0 or the original value if AddTo is triggered.

Examples::

data = [2, 3, 0]
indices = [[1, 1, 0], [0, 1, 0]]
shape = (2, 2)
scatter_nd_acc(data, indices, shape) = [[0, 0], [2, 3]] # Same as scatter_nd

# The difference between scatter_nd and scatter_nd_acc is the latter will accumulate
# the values that point to the same index.

data = [2, 3, 0]
indices = [[1, 1, 0], [1, 1, 0]]
shape = (2, 2)
scatter_nd_acc(data, indices, shape) = [[0, 0], [0, 5]]

)code")
.set_num_outputs(1)
.set_num_inputs(2)
.set_attr_parser(ParamParser<ScatterNDParam>)
.set_attr<nnvm::FListInputNames>("FListInputNames",
[](const NodeAttrs& attrs) {
return std::vector<std::string>{"data", "indices"};
})
.set_attr<nnvm::FInferShape>("FInferShape", ScatterNDShape)
.set_attr<nnvm::FInferType>("FInferType", ScatterNDType)
.set_attr<FCompute>("FCompute<cpu>", ScatterNDAccForward<cpu>)
.set_attr<nnvm::FGradient>("FGradient",
[](const nnvm::NodePtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
auto p = nnvm::Node::Create();
p->attrs.op = nnvm::Op::Get("gather_nd");
p->attrs.name = n->attrs.name + "_backward";
p->inputs.push_back(ograds[0]);
p->inputs.push_back(n->inputs[1]);
p->control_deps.emplace_back(n);
auto zero = MakeNode("zeros_like", n->attrs.name + "_backward_indices",
{n->inputs[1]}, nullptr, &n);
std::vector<nnvm::NodeEntry> ret;
ret.emplace_back(nnvm::NodeEntry{p, 0, 0});
ret.emplace_back(nnvm::NodeEntry{zero, 0, 0});
return ret;
})
.set_attr<nnvm::TIsBackward>("TIsBackward", true)
.add_argument("data", "NDArray-or-Symbol", "data")
.add_argument("indices", "NDArray-or-Symbol", "indices")
.add_arguments(ScatterNDParam::__FIELDS__());

NNVM_REGISTER_OP(_scatter_set_nd)
.describe(R"code(This operator has the same functionality as scatter_nd
except that it does not reset the elements not indexed by the input
index `NDArray` in the input data `NDArray`.

.. note:: This operator is for internal use only.

Examples::
34 changes: 34 additions & 0 deletions src/operator/tensor/indexing_op.cu
@@ -179,6 +179,37 @@ inline void SparseEmbeddingOpBackwardRspImpl<gpu>(const OpContext& ctx,
});
}

template<typename DType, typename IType>
__global__ void ScatterNDAccForwardImplKernel(int N, int M, int K,
const mshadow::Shape<10> strides,
Contributor: Indent.

DType* out,
const DType* data,
const IType* indices) {
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N; i += blockDim.x * gridDim.x) {
int offset = 0;
for (int j = 0; j < M; ++j) {
offset += strides[j] * static_cast<int>(indices[j*N + i]);
}
for (int j = 0; j < K; ++j) {
atomicAdd(out + (offset + j), data[i * K + j]);
}
}
}

template<typename DType, typename IType>
inline void ScatterNDAccForwardImpl(int N, int M, int K,
const mshadow::Shape<10> strides,
DType* out,
const DType* data,
const IType* indices,
mshadow::Stream<gpu> *s) {
using namespace mshadow::cuda;
int ngrid = std::min(kMaxGridNum, (N + kBaseThreadNum - 1) / kBaseThreadNum);
ScatterNDAccForwardImplKernel
Contributor: Does Kernel::Launch not fit here?

Member Author: It does not fit due to the atomicAdd.

Contributor: Why does atomicAdd prevent Kernel::Launch from being used?

Member Author: Okay, I can still use Launch, but only for the GPU kernel.

<<<ngrid, kBaseThreadNum, 0, mshadow::Stream<gpu>::GetStream(s) >>>(
N, M, K, strides, out, data, indices);
}
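As a follow-up to the Kernel::Launch discussion above, a rough sketch of what a launch-compatible functor could look like (hypothetical names, not part of the PR). The per-element body only differs in the accumulate step, which is where the atomicAdd constraint comes from:

struct scatter_nd_acc_sketch {
  template<typename DType, typename IType>
  MSHADOW_XINLINE static void Map(int i, int N, int M, int K,
                                  const mshadow::Shape<10> strides,
                                  DType* out, const DType* data,
                                  const IType* indices) {
    int offset = 0;
    for (int j = 0; j < M; ++j) {
      offset += strides[j] * static_cast<int>(indices[j * N + i]);
    }
    for (int j = 0; j < K; ++j) {
#if defined(__CUDACC__)
      atomicAdd(out + (offset + j), data[i * K + j]);  // device build: atomic accumulate
#else
      out[offset + j] += data[i * K + j];  // host build: plain add, not thread-safe as written
#endif
    }
  }
};

// Hypothetical call site, mirroring the existing scatter_nd launch:
// mxnet_op::Kernel<scatter_nd_acc_sketch, gpu>::Launch(
//     s, N, N, M, K, strides, out, data, indices);

This is only a sketch; the PR keeps separate CPU and GPU paths, and the CPU branch here would still need the OpenMP-atomic treatment discussed earlier to be parallel-safe.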

NNVM_REGISTER_OP(Embedding)
.set_attr<FCompute>("FCompute<gpu>", EmbeddingOpForward<gpu>);

@@ -209,6 +240,9 @@ NNVM_REGISTER_OP(gather_nd)
NNVM_REGISTER_OP(scatter_nd)
.set_attr<FCompute>("FCompute<gpu>", ScatterNDForward<gpu>);

NNVM_REGISTER_OP(scatter_nd_acc)
Contributor: The string acc looks ambiguous. I thought it stood for accurate at first, but later realized it means accumulate. It's named scatter_nd_add in TF, and there are also scatter_nd_sub, scatter_nd_mul, and scatter_nd_div. Shall we also call it scatter_nd_add to be precise?

Member Author: There is a slight difference between scatter_nd_add and scatter_nd_acc. In scatter_nd_add, the results are added to another array, while in scatter_nd_acc, the values are added to an all-zero array. The number of arguments is different for the two ops.

.set_attr<FCompute>("FCompute<gpu>", ScatterNDAccForward<gpu>);

NNVM_REGISTER_OP(_scatter_set_nd)
.set_attr<FCompute>("FCompute<gpu>", ScatterSetNDForward<gpu>);
} // namespace op
54 changes: 51 additions & 3 deletions src/operator/tensor/indexing_op.h
@@ -1131,10 +1131,10 @@ void ScatterNDForward(const nnvm::NodeAttrs& attrs,
int K = oshape.ProdShape(M, oshape.ndim());
mshadow::Shape<10> strides;
for (int i = M-1, stride = K; i >= 0; stride *= oshape[i], --i) strides[i] = stride;
if (kWriteTo == req[0]) {
Fill<true>(s, outputs[0], req[0], 0);
}
MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, DType, { // output data type switch
if (kWriteTo == req[0]) {
Fill<true>(s, outputs[0], req[0], 0);
}
MSHADOW_TYPE_SWITCH(inputs[1].type_flag_, IType, { // indices data type switch
mxnet_op::Kernel<scatter_nd, xpu>::Launch(
s, N, req[0], N, M, K, strides, outputs[0].dptr<DType>(),
@@ -1143,6 +1143,54 @@
});
}

template<typename DType, typename IType>
inline void ScatterNDAccForwardImpl(int N, int M, int K,
const mshadow::Shape<10> strides,
DType* out,
const DType* data,
const IType* indices,
mshadow::Stream<cpu> *s);

template<typename DType, typename IType>
inline void ScatterNDAccForwardImpl(int N, int M, int K,
const mshadow::Shape<10> strides,
DType* out,
const DType* data,
const IType* indices,
mshadow::Stream<gpu> *s);

template<typename xpu>
void ScatterNDAccForward(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
using namespace mshadow;
CHECK_EQ(inputs.size(), 2U);
CHECK_EQ(outputs.size(), 1U);
if (req[0] == kNullOp) return;
mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
const TShape& oshape = outputs[0].shape_;
const TShape& ishape = inputs[1].shape_;
int M = ishape[0];
int N = ishape.Size() / M;
int K = oshape.ProdShape(M, oshape.ndim());
mshadow::Shape<10> strides;
for (int i = M-1, stride = K; i >= 0; stride *= oshape[i], --i) strides[i] = stride;
if (kWriteTo == req[0]) {
Fill<true>(s, outputs[0], req[0], 0);
}
MXNET_NO_INT8_TYPE_SWITCH(inputs[0].type_flag_, DType, { // output data type switch
MSHADOW_TYPE_SWITCH(inputs[1].type_flag_, IType, { // indices data type switch
ScatterNDAccForwardImpl(N, M, K, strides,
outputs[0].dptr<DType>(),
inputs[0].dptr<DType>(),
inputs[1].dptr<IType>(),
s);
});
});
}
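To make the stride/offset arithmetic above concrete, a small standalone trace (shape values chosen purely for illustration):

#include <cassert>

int main() {
  // Assumed example: output shape (2, 3, 4) with M = 2 index dimensions,
  // so K = 4 trailing elements per indexed slice.
  const int oshape[3] = {2, 3, 4};
  const int M = 2, K = 4;
  int strides[2];
  // Same loop as in ScatterNDAccForward above.
  for (int i = M - 1, stride = K; i >= 0; stride *= oshape[i], --i) strides[i] = stride;
  assert(strides[0] == 12 && strides[1] == 4);
  // An index column (1, 2) selects the slice out[1][2][0..3],
  // i.e. flat offset 1*12 + 2*4 = 20.
  assert(1 * strides[0] + 2 * strides[1] == 20);
  return 0;
}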

/*!
* This is for internal use only.
* DO NOT call this function unless you have to.
24 changes: 13 additions & 11 deletions tests/python/unittest/test_operator.py
@@ -4315,21 +4315,23 @@ def check(data, idx):
npdata = np.zeros_like(data.asnumpy())
npdata[npidx] = y.asnumpy()
assert (npdata == data.grad.asnumpy()).all()
assert (mx.nd.scatter_nd(y, idx, shape=data.shape).asnumpy() == data.grad.asnumpy()).all()
assert (mx.nd.scatter_nd_acc(y, idx, shape=data.shape).asnumpy() == data.grad.asnumpy()).all()
for dtype in ['int32', 'float16', 'float32', 'float64']:
data = mx.nd.arange(360, dtype=dtype).reshape((3,4,5,6))
idx = mx.nd.array([[1,1,2], [3, 3, 0], [3,2,1]], dtype='int32')
check(data, idx)

data = mx.nd.arange(360, dtype='int32').reshape((3,4,5,6))
idx = mx.nd.array([[1,1,2], [3, 3, 0], [3,2,1]], dtype='int32')
idx = mx.nd.array([[1,1,2], [3,3,0], [3,2,1], [5,2,4]], dtype='int32')

check(data, idx)
check(data, idx)

idx = mx.nd.array([[1,1,2], [3,3,0], [3,2,1], [5,2,4]], dtype='int32')
data = mx.nd.array([2, 3, 0], dtype=dtype)
idx = mx.nd.array([[1, 1, 0], [0, 1, 0]], dtype=dtype)
assert (mx.nd.scatter_nd(data, idx, shape=(2, 2)).asnumpy() == [[0, 0], [2, 3]]).all()

check(data, idx)

data = mx.nd.array([2, 3, 0])
idx = mx.nd.array([[1, 1, 0], [0, 1, 0]])

assert (mx.nd.scatter_nd(data, idx, shape=(2, 2)).asnumpy() == [[0, 0], [2, 3]]).all()
data = mx.nd.array([2, -3, 0], dtype=dtype)
idx = mx.nd.array([[1, 1, 0], [1, 1, 0]], dtype=dtype)
assert (mx.nd.scatter_nd_acc(data, idx, shape=(2, 2)).asnumpy() == [[0, 0], [0, -1]]).all()

def compare_forw_backw_unary_op(
name, forward_mxnet_call, forward_numpy_call,