Skip to content

Commit

Permalink
Fix the gradient of gather_nd (apache#9200)
Browse files Browse the repository at this point in the history
* try to implement scatter_nd_acc

fix

fix

fix

update

only support real_type

update

update

try to fix

update

fix

update

revise test

fix lint

* fix

* mark line as no lint

* fix test

* revise test

* fix test case

* revise

* remove openmp

* update

* update

* update

* update test

* Revert "update test"

This reverts commit 3eb3ac6.

* Revert "update"

This reverts commit a28fa53.

* Revert "update"

This reverts commit e99ffd0.

* Revert "update"

This reverts commit 399ba02.

* add atomic and specialize the behavior of half_t

* use "!" instead of not

* add test

* fix test

* fix test

* fix test

* rename to backward_gather_nd

* fix

* fix

* fix doc
  • Loading branch information
sxjscience authored and piiswrong committed Jan 4, 2018
1 parent d6ff2d4 commit 0394920
Show file tree
Hide file tree
Showing 6 changed files with 277 additions and 24 deletions.
5 changes: 5 additions & 0 deletions src/common/cuda_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -479,6 +479,11 @@ static inline __device__ void atomicAdd(mshadow::half::half_t *address,
} while (assumed != old);
}

// Overload atomicAdd to work for signed int64 on all architectures
static inline __device__ void atomicAdd(int64_t *address, int64_t val) {
atomicAdd(reinterpret_cast<unsigned long long*>(address), static_cast<unsigned long long>(val)); // NOLINT
}

template <typename DType>
__device__ inline DType ldg(const DType* address) {
#if __CUDA_ARCH__ >= 350
Expand Down
44 changes: 44 additions & 0 deletions src/operator/mxnet_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,50 @@ inline int get_num_threads<cpu>(const int N) {
LOG(FATAL) << "ndim=" << NDim << "too large "; \
}

#define MXNET_NO_INT8_TYPE_SWITCH(type, DType, ...) \
switch (type) { \
case mshadow::kFloat32: \
{ \
typedef float DType; \
{__VA_ARGS__} \
} \
break; \
case mshadow::kFloat64: \
{ \
typedef double DType; \
{__VA_ARGS__} \
} \
break; \
case mshadow::kFloat16: \
{ \
typedef mshadow::half::half_t DType; \
{__VA_ARGS__} \
} \
break; \
case mshadow::kUint8: \
LOG(FATAL) << "This operation does not " \
"support int8 or uint8"; \
break; \
case mshadow::kInt8: \
LOG(FATAL) << "This operation does not " \
"support int8 or uint8"; \
break; \
case mshadow::kInt32: \
{ \
typedef int32_t DType; \
{__VA_ARGS__} \
} \
break; \
case mshadow::kInt64: \
{ \
typedef int64_t DType; \
{__VA_ARGS__} \
} \
break; \
default: \
LOG(FATAL) << "Unknown type enum " << type; \
}


/*!
* \brief assign the val to out according
Expand Down
118 changes: 112 additions & 6 deletions src/operator/tensor/indexing_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,46 @@ inline void SparseEmbeddingOpBackwardRspImpl<cpu>(const OpContext& ctx,
}


template<typename DType, typename IType>
inline typename std::enable_if<(!std::is_same<DType, mshadow::half::half_t>::value), void>::type
GatherNDBackwardImpl(int N, int M, int K,
const mshadow::Shape<10> strides,
DType* out,
const DType* data,
const IType* indices,
mshadow::Stream<cpu> *s) {
#pragma omp parallel for
for (int i = 0; i < N; i++) {
int offset = 0;
for (int j = 0; j < M; ++j) {
offset += strides[j] * static_cast<int>(indices[j*N + i]);
}
for (int j = 0; j < K; ++j) {
#pragma omp atomic
out[offset + j] += data[i * K + j];
}
}
}

template<typename DType, typename IType>
inline typename std::enable_if<std::is_same<DType, mshadow::half::half_t>::value, void>::type
GatherNDBackwardImpl(int N, int M, int K,
const mshadow::Shape<10> strides,
DType* out,
const DType* data,
const IType* indices,
mshadow::Stream<cpu> *s) {
for (int i = 0; i < N; i++) {
int offset = 0;
for (int j = 0; j < M; ++j) {
offset += strides[j] * static_cast<int>(indices[j*N + i]);
}
for (int j = 0; j < K; ++j) {
out[offset + j] += data[i * K + j];
}
}
}

DMLC_REGISTER_PARAMETER(EmbeddingParam);
DMLC_REGISTER_PARAMETER(TakeParam);
DMLC_REGISTER_PARAMETER(OneHotParam);
Expand Down Expand Up @@ -443,8 +483,7 @@ Examples::

NNVM_REGISTER_OP(gather_nd)
.describe(R"code(Gather elements or slices from `data` and store to a tensor whose
shape is defined by `indices`. `gather_nd` and `scatter_nd` are inverse functions
to each other.
shape is defined by `indices`.
Given `data` with shape `(X_0, X_1, ..., X_{N-1})` and indices with shape
`(M, Y_0, ..., Y_{K-1})`, the output will have shape `(Y_0, ..., Y_{K-1}, X_M, ..., X_{N-1})`,
Expand Down Expand Up @@ -476,13 +515,14 @@ Examples::
.set_attr<nnvm::FGradient>("FGradient",
[](const nnvm::NodePtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
auto p = nnvm::Node::Create();
p->attrs.op = nnvm::Op::Get("scatter_nd");
p->attrs.op = nnvm::Op::Get("_backward_gather_nd");
p->attrs.name = n->attrs.name + "_backward";
p->inputs.push_back(ograds[0]);
p->inputs.push_back(n->inputs[1]);
p->control_deps.emplace_back(n);
auto zero = MakeNode("zeros_like", n->attrs.name + "_backward_indices",
{n->inputs[1]}, nullptr, &n);

std::vector<nnvm::NodeEntry> ret;
ret.emplace_back(nnvm::NodeEntry{p, 0, 0});
ret.emplace_back(nnvm::NodeEntry{zero, 0, 0});
Expand All @@ -492,10 +532,8 @@ Examples::
.add_argument("data", "NDArray-or-Symbol", "data")
.add_argument("indices", "NDArray-or-Symbol", "indices");


NNVM_REGISTER_OP(scatter_nd)
.describe(R"code(Scatters data into a new tensor according to indices.
`gather_nd` and `scatter_nd` are inverse functions to each other.
Given `data` with shape `(Y_0, ..., Y_{K-1}, X_M, ..., X_{N-1})` and indices with shape
`(M, Y_0, ..., Y_{K-1})`, the output will have shape `(X_0, X_1, ..., X_{N-1})`,
Expand All @@ -510,6 +548,12 @@ The elements in output is defined as follows::
all other entries in output are 0.
.. warning::
If the indices have duplicates, the result will be non-deterministic and
the gradient of `scatter_nd` will not be correct!!
Examples::
data = [2, 3, 0]
Expand Down Expand Up @@ -548,11 +592,73 @@ Examples::
.add_argument("indices", "NDArray-or-Symbol", "indices")
.add_arguments(ScatterNDParam::__FIELDS__());

NNVM_REGISTER_OP(_backward_gather_nd)
.describe(R"code(Accumulates data according to indices and get the result. It's the backward of
`gather_nd`.
Given `data` with shape `(Y_0, ..., Y_{K-1}, X_M, ..., X_{N-1})` and indices with shape
`(M, Y_0, ..., Y_{K-1})`, the output will have shape `(X_0, X_1, ..., X_{N-1})`,
where `M <= N`. If `M == N`, data shape should simply be `(Y_0, ..., Y_{K-1})`.
The elements in output is defined as follows::
output[indices[0, y_0, ..., y_{K-1}],
...,
indices[M-1, y_0, ..., y_{K-1}],
x_M, ..., x_{N-1}] += data[y_0, ..., y_{K-1}, x_M, ..., x_{N-1}]
all other entries in output are 0 or the original value if AddTo is triggered.
Examples::
data = [2, 3, 0]
indices = [[1, 1, 0], [0, 1, 0]]
shape = (2, 2)
_backward_gather_nd(data, indices, shape) = [[0, 0], [2, 3]] # Same as scatter_nd
# The difference between scatter_nd and scatter_nd_acc is the latter will accumulate
# the values that point to the same index.
data = [2, 3, 0]
indices = [[1, 1, 0], [1, 1, 0]]
shape = (2, 2)
_backward_gather_nd(data, indices, shape) = [[0, 0], [0, 5]]
)code")
.set_num_outputs(1)
.set_num_inputs(2)
.set_attr_parser(ParamParser<ScatterNDParam>)
.set_attr<nnvm::FListInputNames>("FListInputNames",
[](const NodeAttrs& attrs) {
return std::vector<std::string>{"data", "indices"};
})
.set_attr<nnvm::FInferShape>("FInferShape", ScatterNDShape)
.set_attr<nnvm::FInferType>("FInferType", ScatterNDType)
.set_attr<FCompute>("FCompute<cpu>", GatherNDBackward<cpu>)
.set_attr<nnvm::FGradient>("FGradient",
[](const nnvm::NodePtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
auto p = nnvm::Node::Create();
p->attrs.op = nnvm::Op::Get("gather_nd");
p->attrs.name = n->attrs.name + "_backward";
p->inputs.push_back(ograds[0]);
p->inputs.push_back(n->inputs[1]);
p->control_deps.emplace_back(n);
auto zero = MakeNode("zeros_like", n->attrs.name + "_backward_indices",
{n->inputs[1]}, nullptr, &n);
std::vector<nnvm::NodeEntry> ret;
ret.emplace_back(nnvm::NodeEntry{p, 0, 0});
ret.emplace_back(nnvm::NodeEntry{zero, 0, 0});
return ret;
})
.set_attr<nnvm::TIsBackward>("TIsBackward", true)
.add_argument("data", "NDArray-or-Symbol", "data")
.add_argument("indices", "NDArray-or-Symbol", "indices")
.add_arguments(ScatterNDParam::__FIELDS__());

NNVM_REGISTER_OP(_scatter_set_nd)
.describe(R"code(This operator has the same functionality as scatter_nd
except that it does not reset the elements not indexed by the input
index `NDArray` in the input data `NDArray`.
.. note:: This operator is for internal use only.
Examples::
Expand Down
29 changes: 29 additions & 0 deletions src/operator/tensor/indexing_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,32 @@ inline void SparseEmbeddingOpBackwardRspImpl<gpu>(const OpContext& ctx,
});
}

struct backward_gather_nd_gpu {
template<typename DType, typename IType>
MSHADOW_XINLINE static void Map(int i, int N, int M, int K,
const mshadow::Shape<10> strides,
DType* out, const DType* data,
const IType* indices) {
int offset = 0;
for (int j = 0; j < M; ++j) {
offset += strides[j] * static_cast<int>(indices[j*N + i]);
}
for (int j = 0; j < K; ++j) {
atomicAdd(out + (offset + j), data[i * K + j]);
}
}
};

template<typename DType, typename IType>
inline void GatherNDBackwardImpl(int N, int M, int K,
const mshadow::Shape<10> strides,
DType* out,
const DType* data,
const IType* indices,
mshadow::Stream<gpu> *s) {
mxnet_op::Kernel<backward_gather_nd_gpu, gpu>::Launch(s, N, N, M, K, strides, out, data, indices);
}

NNVM_REGISTER_OP(Embedding)
.set_attr<FCompute>("FCompute<gpu>", EmbeddingOpForward<gpu>);

Expand Down Expand Up @@ -209,6 +235,9 @@ NNVM_REGISTER_OP(gather_nd)
NNVM_REGISTER_OP(scatter_nd)
.set_attr<FCompute>("FCompute<gpu>", ScatterNDForward<gpu>);

NNVM_REGISTER_OP(_backward_gather_nd)
.set_attr<FCompute>("FCompute<gpu>", GatherNDBackward<gpu>);

NNVM_REGISTER_OP(_scatter_set_nd)
.set_attr<FCompute>("FCompute<gpu>", ScatterSetNDForward<gpu>);
} // namespace op
Expand Down
64 changes: 61 additions & 3 deletions src/operator/tensor/indexing_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -1131,10 +1131,10 @@ void ScatterNDForward(const nnvm::NodeAttrs& attrs,
int K = oshape.ProdShape(M, oshape.ndim());
mshadow::Shape<10> strides;
for (int i = M-1, stride = K; i >= 0; stride *= oshape[i], --i) strides[i] = stride;
if (kWriteTo == req[0]) {
Fill<true>(s, outputs[0], req[0], 0);
}
MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, DType, { // output data type switch
if (kWriteTo == req[0]) {
Fill<true>(s, outputs[0], req[0], 0);
}
MSHADOW_TYPE_SWITCH(inputs[1].type_flag_, IType, { // indices data type switch
mxnet_op::Kernel<scatter_nd, xpu>::Launch(
s, N, req[0], N, M, K, strides, outputs[0].dptr<DType>(),
Expand All @@ -1143,6 +1143,64 @@ void ScatterNDForward(const nnvm::NodeAttrs& attrs,
});
}

template<typename DType, typename IType>
inline typename std::enable_if<(!std::is_same<DType, mshadow::half::half_t>::value), void>::type
GatherNDBackwardImpl(int N, int M, int K,
const mshadow::Shape<10> strides,
DType* out,
const DType* data,
const IType* indices,
mshadow::Stream<cpu> *s);

template<typename DType, typename IType>
inline typename std::enable_if<std::is_same<DType, mshadow::half::half_t>::value, void>::type
GatherNDBackwardImpl(int N, int M, int K,
const mshadow::Shape<10> strides,
DType* out,
const DType* data,
const IType* indices,
mshadow::Stream<cpu> *s);

template<typename DType, typename IType>
inline void GatherNDBackwardImpl(int N, int M, int K,
const mshadow::Shape<10> strides,
DType* out,
const DType* data,
const IType* indices,
mshadow::Stream<gpu> *s);

template<typename xpu>
void GatherNDBackward(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
using namespace mshadow;
CHECK_EQ(inputs.size(), 2U);
CHECK_EQ(outputs.size(), 1U);
if (req[0] == kNullOp) return;
mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
const TShape& oshape = outputs[0].shape_;
const TShape& ishape = inputs[1].shape_;
int M = ishape[0];
int N = ishape.Size() / M;
int K = oshape.ProdShape(M, oshape.ndim());
mshadow::Shape<10> strides;
for (int i = M-1, stride = K; i >= 0; stride *= oshape[i], --i) strides[i] = stride;
if (kWriteTo == req[0]) {
Fill<true>(s, outputs[0], req[0], 0);
}
MXNET_NO_INT8_TYPE_SWITCH(inputs[0].type_flag_, DType, { // output data type switch
MSHADOW_TYPE_SWITCH(inputs[1].type_flag_, IType, { // indices data type switch
GatherNDBackwardImpl(N, M, K, strides,
outputs[0].dptr<DType>(),
inputs[0].dptr<DType>(),
inputs[1].dptr<IType>(),
s);
});
});
}

/*!
* This is for internal use only.
* DO NOT call this function unless you have to.
Expand Down
Loading

0 comments on commit 0394920

Please sign in to comment.