-
Notifications
You must be signed in to change notification settings - Fork 6.8k
Fix the gradient of gather_nd #9200
Changes from 4 commits
1bf688b
3469c19
b8f420a
0e56e75
37b2d32
5760392
d869a85
5cf2d8a
399ba02
e99ffd0
a28fa53
3eb3ac6
6b8c348
026e709
0c9cddc
8ca6c8f
6ccf08b
808d331
7c3c327
7be9821
4e19b0b
04375dd
b1f5fc3
ec2f169
4d0b31a
f6ffa4d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -137,6 +137,24 @@ inline void SparseEmbeddingOpBackwardRspImpl<cpu>(const OpContext& ctx, | |
} | ||
|
||
|
||
template<typename DType, typename IType> | ||
inline void ScatterNDAccForwardImpl(int N, int M, int K, | ||
const mshadow::Shape<10> strides, | ||
DType* out, | ||
const DType* data, | ||
const IType* indices, | ||
mshadow::Stream<cpu> *s) { | ||
for (int i = 0; i < N; i++) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is single-threaded. Can we use There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, we can use openmp. Let me have a try. |
||
int offset = 0; | ||
for (int j = 0; j < M; ++j) { | ||
offset += strides[j] * static_cast<int>(indices[j*N + i]); | ||
} | ||
for (int j = 0; j < K; ++j) { | ||
out[offset + j] += data[i * K + j]; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You can consolidate this with the gpu kernel by using There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @reminisce I've specialized the implementation of half_t and now it passes the test |
||
} | ||
} | ||
} | ||
|
||
DMLC_REGISTER_PARAMETER(EmbeddingParam); | ||
DMLC_REGISTER_PARAMETER(TakeParam); | ||
DMLC_REGISTER_PARAMETER(OneHotParam); | ||
|
@@ -443,7 +461,7 @@ Examples:: | |
|
||
NNVM_REGISTER_OP(gather_nd) | ||
.describe(R"code(Gather elements or slices from `data` and store to a tensor whose | ||
shape is defined by `indices`. `gather_nd` and `scatter_nd` are inverse functions | ||
shape is defined by `indices`. `gather_nd` and `scatter_nd_acc` are inverse functions | ||
to each other. | ||
|
||
Given `data` with shape `(X_0, X_1, ..., X_{N-1})` and indices with shape | ||
|
@@ -476,13 +494,14 @@ Examples:: | |
.set_attr<nnvm::FGradient>("FGradient", | ||
[](const nnvm::NodePtr& n, const std::vector<nnvm::NodeEntry>& ograds) { | ||
auto p = nnvm::Node::Create(); | ||
p->attrs.op = nnvm::Op::Get("scatter_nd"); | ||
p->attrs.op = nnvm::Op::Get("scatter_nd_acc"); | ||
p->attrs.name = n->attrs.name + "_backward"; | ||
p->inputs.push_back(ograds[0]); | ||
p->inputs.push_back(n->inputs[1]); | ||
p->control_deps.emplace_back(n); | ||
auto zero = MakeNode("zeros_like", n->attrs.name + "_backward_indices", | ||
{n->inputs[1]}, nullptr, &n); | ||
|
||
std::vector<nnvm::NodeEntry> ret; | ||
ret.emplace_back(nnvm::NodeEntry{p, 0, 0}); | ||
ret.emplace_back(nnvm::NodeEntry{zero, 0, 0}); | ||
|
@@ -492,10 +511,8 @@ Examples:: | |
.add_argument("data", "NDArray-or-Symbol", "data") | ||
.add_argument("indices", "NDArray-or-Symbol", "indices"); | ||
|
||
|
||
NNVM_REGISTER_OP(scatter_nd) | ||
.describe(R"code(Scatters data into a new tensor according to indices. | ||
`gather_nd` and `scatter_nd` are inverse functions to each other. | ||
|
||
Given `data` with shape `(Y_0, ..., Y_{K-1}, X_M, ..., X_{N-1})` and indices with shape | ||
`(M, Y_0, ..., Y_{K-1})`, the output will have shape `(X_0, X_1, ..., X_{N-1})`, | ||
|
@@ -510,6 +527,10 @@ The elements in output is defined as follows:: | |
|
||
all other entries in output are 0. | ||
|
||
.. note:: If the indices contain duplicates, the result is non-deterministic
   (which duplicate wins is unspecified), and consequently the gradient of
   `scatter_nd` will not be correct in that case.
|
||
|
||
Examples:: | ||
|
||
data = [2, 3, 0] | ||
|
@@ -548,11 +569,72 @@ Examples:: | |
.add_argument("indices", "NDArray-or-Symbol", "indices") | ||
.add_arguments(ScatterNDParam::__FIELDS__()); | ||
|
||
// Registration of scatter_nd_acc: the accumulating variant of scatter_nd,
// used as the backward op of gather_nd so that duplicate indices sum their
// gradient contributions instead of racing.
NNVM_REGISTER_OP(scatter_nd_acc)
.describe(R"code(Accumulates data according to indices and gets the result.

Given `data` with shape `(Y_0, ..., Y_{K-1}, X_M, ..., X_{N-1})` and indices with shape
`(M, Y_0, ..., Y_{K-1})`, the output will have shape `(X_0, X_1, ..., X_{N-1})`,
where `M <= N`. If `M == N`, data shape should simply be `(Y_0, ..., Y_{K-1})`.

The elements in output are defined as follows::

  output[indices[0, y_0, ..., y_{K-1}],
         ...,
         indices[M-1, y_0, ..., y_{K-1}],
         x_M, ..., x_{N-1}] += data[y_0, ..., y_{K-1}, x_M, ..., x_{N-1}]

all other entries in output are 0 or the original value if AddTo is triggered.

Examples::

  data = [2, 3, 0]
  indices = [[1, 1, 0], [0, 1, 0]]
  shape = (2, 2)
  scatter_nd_acc(data, indices, shape) = [[0, 0], [2, 3]]  # Same as scatter_nd

  # The difference between scatter_nd and scatter_nd_acc is that the latter
  # accumulates the values that point to the same index.

  data = [2, 3, 0]
  indices = [[1, 1, 0], [1, 1, 0]]
  shape = (2, 2)
  scatter_nd_acc(data, indices, shape) = [[0, 0], [0, 5]]

)code")
.set_num_outputs(1)
.set_num_inputs(2)
.set_attr_parser(ParamParser<ScatterNDParam>)
.set_attr<nnvm::FListInputNames>("FListInputNames",
  [](const NodeAttrs& attrs) {
    return std::vector<std::string>{"data", "indices"};
  })
.set_attr<nnvm::FInferShape>("FInferShape", ScatterNDShape)
.set_attr<nnvm::FInferType>("FInferType", ScatterNDType)
.set_attr<FCompute>("FCompute<cpu>", ScatterNDAccForward<cpu>)
// Gradient w.r.t. `data` is gather_nd with the same indices; the indices
// themselves are non-differentiable and receive a zeros_like gradient.
.set_attr<nnvm::FGradient>("FGradient",
  [](const nnvm::NodePtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
    auto p = nnvm::Node::Create();
    p->attrs.op = nnvm::Op::Get("gather_nd");
    p->attrs.name = n->attrs.name + "_backward";
    p->inputs.push_back(ograds[0]);
    p->inputs.push_back(n->inputs[1]);
    p->control_deps.emplace_back(n);
    auto zero = MakeNode("zeros_like", n->attrs.name + "_backward_indices",
                         {n->inputs[1]}, nullptr, &n);
    std::vector<nnvm::NodeEntry> ret;
    ret.emplace_back(nnvm::NodeEntry{p, 0, 0});
    ret.emplace_back(nnvm::NodeEntry{zero, 0, 0});
    return ret;
  })
// NOTE(review): TIsBackward is set although the op is also exposed as a
// user-callable forward op with its own FGradient — confirm this is the
// intended registration rather than a leftover from the backward-only design.
.set_attr<nnvm::TIsBackward>("TIsBackward", true)
.add_argument("data", "NDArray-or-Symbol", "data")
.add_argument("indices", "NDArray-or-Symbol", "indices")
.add_arguments(ScatterNDParam::__FIELDS__());
|
||
NNVM_REGISTER_OP(_scatter_set_nd) | ||
.describe(R"code(This operator has the same functionality as scatter_nd | ||
except that it does not reset the elements not indexed by the input | ||
index `NDArray` in the input data `NDArray`. | ||
|
||
.. note:: This operator is for internal use only. | ||
|
||
Examples:: | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -179,6 +179,37 @@ inline void SparseEmbeddingOpBackwardRspImpl<gpu>(const OpContext& ctx, | |
}); | ||
} | ||
|
||
/*!
 * \brief GPU kernel for the scatter_nd_acc forward pass.
 *
 * Grid-stride loop over the N index columns: each thread folds the M index
 * components of its column into a flat offset via `strides`, then accumulates
 * the K-element data slice into `out` with atomicAdd so that duplicate
 * indices from different threads sum correctly.
 *
 * NOTE(review): atomicAdd availability depends on DType — double needs
 * compute capability >= 6.0 and half needs a specialization; confirm the
 * dispatch excludes or specializes unsupported types.
 */
template<typename DType, typename IType>
__global__ void ScatterNDAccForwardImplKernel(int N, int M, int K,
                                              const mshadow::Shape<10> strides,
                                              DType* out,
                                              const DType* data,
                                              const IType* indices) {
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N; i += blockDim.x * gridDim.x) {
    int offset = 0;
    for (int j = 0; j < M; ++j) {
      offset += strides[j] * static_cast<int>(indices[j*N + i]);
    }
    for (int j = 0; j < K; ++j) {
      // Atomic: multiple columns may map to the same output offset.
      atomicAdd(out + (offset + j), data[i * K + j]);
    }
  }
}
|
||
/*!
 * \brief GPU launcher for the scatter_nd_acc forward kernel.
 *
 * Sizes the grid to cover the N index columns (capped at kMaxGridNum; the
 * kernel's grid-stride loop handles any remainder) and launches on the
 * stream owned by `s`.
 */
template<typename DType, typename IType>
inline void ScatterNDAccForwardImpl(int N, int M, int K,
                                    const mshadow::Shape<10> strides,
                                    DType* out,
                                    const DType* data,
                                    const IType* indices,
                                    mshadow::Stream<gpu> *s) {
  using namespace mshadow::cuda;
  // One thread per column, capped; the kernel strides over the rest.
  int ngrid = std::min(kMaxGridNum, (N + kBaseThreadNum - 1) / kBaseThreadNum);
  ScatterNDAccForwardImplKernel
    <<<ngrid, kBaseThreadNum, 0, mshadow::Stream<gpu>::GetStream(s)>>>(
      N, M, K, strides, out, data, indices);
  // Surface asynchronous launch failures immediately instead of at the next sync.
  MSHADOW_CUDA_POST_KERNEL_CHECK(ScatterNDAccForwardImplKernel);
}
|
||
NNVM_REGISTER_OP(Embedding) | ||
.set_attr<FCompute>("FCompute<gpu>", EmbeddingOpForward<gpu>); | ||
|
||
|
@@ -209,6 +240,9 @@ NNVM_REGISTER_OP(gather_nd) | |
NNVM_REGISTER_OP(scatter_nd) | ||
.set_attr<FCompute>("FCompute<gpu>", ScatterNDForward<gpu>); | ||
|
||
NNVM_REGISTER_OP(scatter_nd_acc) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The string There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There is a slight difference between |
||
.set_attr<FCompute>("FCompute<gpu>", ScatterNDAccForward<gpu>); | ||
|
||
NNVM_REGISTER_OP(_scatter_set_nd) | ||
.set_attr<FCompute>("FCompute<gpu>", ScatterSetNDForward<gpu>); | ||
} // namespace op | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Are you sure this works for negative values?

It should be safe as long as CUDA implements signed long long with two's
complement, in which case unsigned atomicAdd produces the correct bit pattern.

See this related discussion: https://devtalk.nvidia.com/default/topic/506105/atomicadd-with-signed-long-long-not-working/