This repository has been archived by the owner on Nov 17, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 6.8k
Fix the gradient of gather_nd #9200
Merged
Merged
Changes from all commits
Commits
Show all changes
26 commits
Select commit
Hold shift + click to select a range
1bf688b
try to implement scatter_nd_acc
sxjscience 3469c19
fix
sxjscience b8f420a
mark line as no lint
sxjscience 0e56e75
fix test
sxjscience 37b2d32
revise test
sxjscience 5760392
fix test case
sxjscience d869a85
revise
sxjscience 5cf2d8a
remove openmp
sxjscience 399ba02
update
sxjscience e99ffd0
update
sxjscience a28fa53
update
sxjscience 3eb3ac6
update test
sxjscience 6b8c348
Revert "update test"
sxjscience 026e709
Revert "update"
sxjscience 0c9cddc
Revert "update"
sxjscience 8ca6c8f
Revert "update"
sxjscience 6ccf08b
add atomic and specialize the behavior of half_t
sxjscience 808d331
use "!" instead of not
sxjscience 7c3c327
add test
sxjscience 7be9821
fix test
sxjscience 4e19b0b
fix test
sxjscience 04375dd
fix test
sxjscience b1f5fc3
rename to backward_gather_nd
sxjscience ec2f169
fix
sxjscience 4d0b31a
fix
sxjscience f6ffa4d
fix doc
sxjscience File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -137,6 +137,46 @@ inline void SparseEmbeddingOpBackwardRspImpl<cpu>(const OpContext& ctx, | |
} | ||
|
||
|
||
template<typename DType, typename IType> | ||
inline typename std::enable_if<(!std::is_same<DType, mshadow::half::half_t>::value), void>::type | ||
GatherNDBackwardImpl(int N, int M, int K, | ||
const mshadow::Shape<10> strides, | ||
DType* out, | ||
const DType* data, | ||
const IType* indices, | ||
mshadow::Stream<cpu> *s) { | ||
#pragma omp parallel for | ||
for (int i = 0; i < N; i++) { | ||
int offset = 0; | ||
for (int j = 0; j < M; ++j) { | ||
offset += strides[j] * static_cast<int>(indices[j*N + i]); | ||
} | ||
for (int j = 0; j < K; ++j) { | ||
#pragma omp atomic | ||
out[offset + j] += data[i * K + j]; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You can consolidate this with the gpu kernel by using There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @reminisce I've specialized the implementation of half_t and now it passes the test |
||
} | ||
} | ||
} | ||
|
||
template<typename DType, typename IType> | ||
inline typename std::enable_if<std::is_same<DType, mshadow::half::half_t>::value, void>::type | ||
GatherNDBackwardImpl(int N, int M, int K, | ||
const mshadow::Shape<10> strides, | ||
DType* out, | ||
const DType* data, | ||
const IType* indices, | ||
mshadow::Stream<cpu> *s) { | ||
for (int i = 0; i < N; i++) { | ||
int offset = 0; | ||
for (int j = 0; j < M; ++j) { | ||
offset += strides[j] * static_cast<int>(indices[j*N + i]); | ||
} | ||
for (int j = 0; j < K; ++j) { | ||
out[offset + j] += data[i * K + j]; | ||
} | ||
} | ||
} | ||
|
||
DMLC_REGISTER_PARAMETER(EmbeddingParam); | ||
DMLC_REGISTER_PARAMETER(TakeParam); | ||
DMLC_REGISTER_PARAMETER(OneHotParam); | ||
|
@@ -443,8 +483,7 @@ Examples:: | |
|
||
NNVM_REGISTER_OP(gather_nd) | ||
.describe(R"code(Gather elements or slices from `data` and store to a tensor whose | ||
shape is defined by `indices`. `gather_nd` and `scatter_nd` are inverse functions | ||
to each other. | ||
shape is defined by `indices`. | ||
|
||
Given `data` with shape `(X_0, X_1, ..., X_{N-1})` and indices with shape | ||
`(M, Y_0, ..., Y_{K-1})`, the output will have shape `(Y_0, ..., Y_{K-1}, X_M, ..., X_{N-1})`, | ||
|
@@ -476,13 +515,14 @@ Examples:: | |
.set_attr<nnvm::FGradient>("FGradient", | ||
[](const nnvm::NodePtr& n, const std::vector<nnvm::NodeEntry>& ograds) { | ||
auto p = nnvm::Node::Create(); | ||
p->attrs.op = nnvm::Op::Get("scatter_nd"); | ||
p->attrs.op = nnvm::Op::Get("_backward_gather_nd"); | ||
p->attrs.name = n->attrs.name + "_backward"; | ||
p->inputs.push_back(ograds[0]); | ||
p->inputs.push_back(n->inputs[1]); | ||
p->control_deps.emplace_back(n); | ||
auto zero = MakeNode("zeros_like", n->attrs.name + "_backward_indices", | ||
{n->inputs[1]}, nullptr, &n); | ||
|
||
std::vector<nnvm::NodeEntry> ret; | ||
ret.emplace_back(nnvm::NodeEntry{p, 0, 0}); | ||
ret.emplace_back(nnvm::NodeEntry{zero, 0, 0}); | ||
|
@@ -492,10 +532,8 @@ Examples:: | |
.add_argument("data", "NDArray-or-Symbol", "data") | ||
.add_argument("indices", "NDArray-or-Symbol", "indices"); | ||
|
||
|
||
NNVM_REGISTER_OP(scatter_nd) | ||
.describe(R"code(Scatters data into a new tensor according to indices. | ||
`gather_nd` and `scatter_nd` are inverse functions to each other. | ||
|
||
Given `data` with shape `(Y_0, ..., Y_{K-1}, X_M, ..., X_{N-1})` and indices with shape | ||
`(M, Y_0, ..., Y_{K-1})`, the output will have shape `(X_0, X_1, ..., X_{N-1})`, | ||
|
@@ -510,6 +548,12 @@ The elements in output is defined as follows:: | |
|
||
all other entries in output are 0. | ||
|
||
.. warning:: | ||
|
||
If the indices have duplicates, the result will be non-deterministic and | ||
the gradient of `scatter_nd` will not be correct!! | ||
|
||
|
||
Examples:: | ||
|
||
data = [2, 3, 0] | ||
|
@@ -548,11 +592,73 @@ Examples:: | |
.add_argument("indices", "NDArray-or-Symbol", "indices") | ||
.add_arguments(ScatterNDParam::__FIELDS__()); | ||
|
||
// Registration of _backward_gather_nd: the accumulating scatter used as the
// gradient of gather_nd. Unlike scatter_nd, duplicate indices accumulate
// instead of overwriting, so the gradient is correct for repeated indices.
// NOTE: the describe string previously referred to the op's pre-rename name
// "scatter_nd_acc"; it is updated here to the registered name.
NNVM_REGISTER_OP(_backward_gather_nd)
.describe(R"code(Accumulates data according to indices and get the result. It's the backward of
`gather_nd`.

Given `data` with shape `(Y_0, ..., Y_{K-1}, X_M, ..., X_{N-1})` and indices with shape
`(M, Y_0, ..., Y_{K-1})`, the output will have shape `(X_0, X_1, ..., X_{N-1})`,
where `M <= N`. If `M == N`, data shape should simply be `(Y_0, ..., Y_{K-1})`.

The elements in output is defined as follows::

  output[indices[0, y_0, ..., y_{K-1}],
         ...,
         indices[M-1, y_0, ..., y_{K-1}],
         x_M, ..., x_{N-1}] += data[y_0, ..., y_{K-1}, x_M, ..., x_{N-1}]

all other entries in output are 0 or the original value if AddTo is triggered.

Examples::

  data = [2, 3, 0]
  indices = [[1, 1, 0], [0, 1, 0]]
  shape = (2, 2)
  _backward_gather_nd(data, indices, shape) = [[0, 0], [2, 3]] # Same as scatter_nd

  # The difference between scatter_nd and _backward_gather_nd is the latter will accumulate
  #  the values that point to the same index.

  data = [2, 3, 0]
  indices = [[1, 1, 0], [1, 1, 0]]
  shape = (2, 2)
  _backward_gather_nd(data, indices, shape) = [[0, 0], [0, 5]]

)code")
.set_num_outputs(1)
.set_num_inputs(2)
// Reuses scatter_nd's parameter (target shape), shape/type inference.
.set_attr_parser(ParamParser<ScatterNDParam>)
.set_attr<nnvm::FListInputNames>("FListInputNames",
  [](const NodeAttrs& attrs) {
    return std::vector<std::string>{"data", "indices"};
  })
.set_attr<nnvm::FInferShape>("FInferShape", ScatterNDShape)
.set_attr<nnvm::FInferType>("FInferType", ScatterNDType)
.set_attr<FCompute>("FCompute<cpu>", GatherNDBackward<cpu>)
.set_attr<nnvm::FGradient>("FGradient",
  [](const nnvm::NodePtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
    // The gradient of an accumulating scatter w.r.t. `data` is gather_nd at
    // the same indices; the indices themselves get a zero gradient.
    auto p = nnvm::Node::Create();
    p->attrs.op = nnvm::Op::Get("gather_nd");
    p->attrs.name = n->attrs.name + "_backward";
    p->inputs.push_back(ograds[0]);
    p->inputs.push_back(n->inputs[1]);
    p->control_deps.emplace_back(n);
    auto zero = MakeNode("zeros_like", n->attrs.name + "_backward_indices",
                         {n->inputs[1]}, nullptr, &n);
    std::vector<nnvm::NodeEntry> ret;
    ret.emplace_back(nnvm::NodeEntry{p, 0, 0});
    ret.emplace_back(nnvm::NodeEntry{zero, 0, 0});
    return ret;
  })
.set_attr<nnvm::TIsBackward>("TIsBackward", true)
.add_argument("data", "NDArray-or-Symbol", "data")
.add_argument("indices", "NDArray-or-Symbol", "indices")
.add_arguments(ScatterNDParam::__FIELDS__());
|
||
NNVM_REGISTER_OP(_scatter_set_nd) | ||
.describe(R"code(This operator has the same functionality as scatter_nd | ||
except that it does not reset the elements not indexed by the input | ||
index `NDArray` in the input data `NDArray`. | ||
|
||
.. note:: This operator is for internal use only. | ||
|
||
Examples:: | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
are you sure this works for negative value?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It should be safe if CUDA uses 2's complement to implement the signed long long.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I found this: https://devtalk.nvidia.com/default/topic/506105/atomicadd-with-signed-long-long-not-working/