This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Fix the gradient of gather_nd #9200

Merged

merged 26 commits into apache:master on Jan 4, 2018

Conversation

Member

@sxjscience sxjscience commented Dec 26, 2017

Description

Add _backward_gather_nd, which accumulates the values when the indices are the same. Should solve #9172.
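The accumulation semantics can be sketched with NumPy (a reference sketch of my own, not the MXNet implementation; the helper name `gather_nd_backward` is hypothetical): when the same location is gathered more than once, its gradient must be the sum of the corresponding output gradients rather than the last write.

```python
import numpy as np

def gather_nd_backward(ograd, indices, data_shape):
    """Reference sketch of the gather_nd gradient: gradients for
    locations selected more than once are accumulated."""
    grad = np.zeros(data_shape, dtype=ograd.dtype)
    # np.add.at accumulates at repeated indices, unlike fancy-index
    # assignment, which silently keeps only the last write.
    np.add.at(grad, tuple(indices), ograd)
    return grad

# Gather row 0 twice: its gradient is the sum of both output grads.
indices = np.array([[0, 0]])            # (M=1, N=2) index matrix
ograd = np.array([[1., 2.], [3., 4.]])  # upstream gradient, shape (2, 2)
grad = gather_nd_backward(ograd, indices, data_shape=(2, 2))
```

With duplicate indices, `grad[0]` receives both `[1, 2]` and `[3, 4]`, while `grad[1]` stays zero.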

Checklist

  • Passed code style checking (make lint)
  • Changes are complete (i.e. I finished coding on this PR)
  • All changes have test coverage:
  • Code is well-documented:
  • To the best of my knowledge, examples are either not affected by this change or have been fixed to be compatible with it

Changes

  • Add new type switch macro that can be used when int8 is not supported
  • atomicAdd support for int64
  • Add _backward_gather_nd
  • Set the gradient of gather_nd to _backward_gather_nd

Comments

I use atomicAdd to implement the operator. The current CPU implementation does not use OpenMP. Also, int8 and uint8 are not supported.

fix

fix

fix

update

only support real_type

update

update

try to fix

update

fix

update

revise test

fix lint
@piiswrong
Contributor

@reminisce

@@ -479,6 +479,11 @@ static inline __device__ void atomicAdd(mshadow::half::half_t *address,
} while (assumed != old);
}

// Overload atomicAdd to work for signed int64 on all architectures
static inline __device__ void atomicAdd(int64_t *address, int64_t val) {
atomicAdd(reinterpret_cast<unsigned long long*>(address), static_cast<unsigned long long>(val)); // NOLINT
Contributor

Are you sure this works for negative values?

Member Author

It should be safe as long as CUDA uses 2's complement to represent signed long long.
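The claim can be checked with a small simulation of 64-bit two's-complement arithmetic (a Python sketch of my own, not CUDA code): adding the reinterpreted unsigned bit patterns modulo 2^64 yields the same bit pattern as signed addition, so the unsigned atomicAdd is correct for negative values too.

```python
MASK = (1 << 64) - 1

def to_u64(x):
    # Reinterpret a signed value as unsigned (two's-complement bits).
    return x & MASK

def to_i64(x):
    # Reinterpret an unsigned 64-bit value back as signed.
    return x - (1 << 64) if x >= (1 << 63) else x

# Unsigned modular addition of the reinterpreted bits...
a, b = -31231236374787, 2123162361283621
u = (to_u64(a) + to_u64(b)) & MASK
# ...matches ordinary signed addition.
assert to_i64(u) == a + b
```

This holds for any operands whose true sum fits in int64, which is why the `reinterpret_cast` trick works on all architectures.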

Member

data = mx.nd.array([2123162361283621, -31231236374787,
                    -112372937128970, -1378278798172378], dtype=dtype)
idx = mx.nd.array([[0, 0, 0, 0]], dtype='int32')
assert (mx.nd.scatter_nd_acc(data, idx, shape=(1,)).asnumpy()[0] == data.asnumpy().sum())
Member Author

@piiswrong I've added another test case for the signed int64 case.

const DType* data,
const IType* indices,
mshadow::Stream<cpu> *s) {
for (int i = 0; i < N; i++) {
Contributor

This is single-threaded. Can we use #pragma omp critical or #pragma omp atomic for the CPU kernel?

Member Author

Yes, we can use OpenMP. Let me give it a try.

@@ -209,6 +240,9 @@ NNVM_REGISTER_OP(gather_nd)
NNVM_REGISTER_OP(scatter_nd)
.set_attr<FCompute>("FCompute<gpu>", ScatterNDForward<gpu>);

NNVM_REGISTER_OP(scatter_nd_acc)
Contributor

The string acc looks ambiguous. I thought it stood for accurate at first, but later realized it means accumulate. It's named scatter_nd_add in TF, where there are also scatter_nd_sub, scatter_nd_mul, and scatter_nd_div. Shall we also call it scatter_nd_add to be precise?

Member Author

There is a slight difference between scatter_nd_add and scatter_nd_acc. In scatter_nd_add, the results are added to another array, while in scatter_nd_acc, the values are added to an all-zero array. The number of arguments is different for these two ops.
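The distinction can be sketched in NumPy (hypothetical helper functions of my own, not the actual operator implementations): the acc variant takes two inputs and starts from zeros, while the TF-style add variant takes a third, pre-existing array to add into.

```python
import numpy as np

def scatter_nd_acc(data, idx, shape):
    """Sketch: accumulate data into a fresh all-zero array (2 inputs)."""
    out = np.zeros(shape, dtype=data.dtype)
    np.add.at(out, tuple(idx), data)  # accumulates at duplicate indices
    return out

def scatter_nd_add(ref, data, idx):
    """Sketch of the TF-style op: add into an existing array (3 inputs)."""
    out = ref.copy()
    np.add.at(out, tuple(idx), data)
    return out

data = np.array([1., 2., 3.])
idx = np.array([[0, 0, 1]])               # duplicate index 0
acc = scatter_nd_acc(data, idx, shape=(2,))  # starts from zeros
add = scatter_nd_add(np.ones(2), data, idx)  # starts from ref array
```

Here `acc` is `[3, 3]` (1 + 2 at index 0, 3 at index 1) while `add` is `[4, 4]` because the ones in `ref` are carried over.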

mshadow::Stream<gpu> *s) {
using namespace mshadow::cuda;
int ngrid = std::min(kMaxGridNum, (N + kBaseThreadNum - 1) / kBaseThreadNum);
ScatterNDAccForwardImplKernel
Contributor

Does Kernel::Launch not fit here?

Member Author

It does not fit due to the atomicAdd.

Contributor

Why does atomicAdd prevent Kernel::Launch from being used?

Member Author

Okay, I can still use Launch, but only for the GPU.

@@ -179,6 +179,37 @@ inline void SparseEmbeddingOpBackwardRspImpl<gpu>(const OpContext& ctx,
});
}

template<typename DType, typename IType>
__global__ void ScatterNDAccForwardImplKernel(int N, int M, int K,
const mshadow::Shape<10> strides,
Contributor

Indent.

}
for (int j = 0; j < K; ++j) {
#pragma omp atomic
out[offset + j] += data[i * K + j];
Contributor

You can consolidate this with the GPU kernel by using #if __CUDA__ #else in the header file, since this line is the only difference between the CPU and GPU kernels. Then, in the FCompute function, you can use Kernel::Launch for both the CPU and GPU kernels. That would make the implementation less verbose.

Member Author

@reminisce I've specialized the implementation for half_t and now it passes the test.

@sxjscience
Member Author

sxjscience commented Dec 28, 2017

It's very strange. The CI test fails on all Windows machines.

This reverts commit 3eb3ac6.
This reverts commit a28fa53.
This reverts commit e99ffd0.
This reverts commit 399ba02.
@sxjscience
Member Author

@reminisce I found I cannot use omp atomic. Also, using omp critical will not have any parallelism. I've reverted to the original version.

@reminisce
Contributor

What is the error when using omp atomic?

@sxjscience
Member Author

sxjscience commented Dec 28, 2017

"#pragma omp atomic" has improper form on Windows
invalid expression type for ‘#pragma omp atomic’ on Linux

@sxjscience
Member Author

@reminisce I think it's caused by mshadow::half::half_t, which is not supported by omp atomic.

@reminisce
Contributor

I see. Is this a runtime error? If only float16 is unsupported, I suggest we use omp atomic for all the other types, since float32 is the most common one.

@sxjscience
Member Author

@piiswrong @reminisce Can it be merged?


assert (mx.nd.scatter_nd(data, idx, shape=(2, 2)).asnumpy() == [[0, 0], [2, 3]]).all()
assert (mx.nd.scatter_nd_acc(y, idx, shape=data.shape).asnumpy() == data.grad.asnumpy()).all()
for dtype in ['int32', 'int64', 'float16', 'float32', 'float64']:
Contributor

It seems that only int64 has been tested for scatter_nd_acc in the same-index case. Could you confirm?

data_npy = np.random.randint(0, 10, (100,))
data = mx.nd.array(data_npy, dtype=dtype)
idx = mx.nd.zeros(shape=(1, 100), dtype='int32')
assert (mx.nd.scatter_nd_acc(data, idx, shape=(1,)).asscalar() == data_npy.sum())
Member Author

@reminisce I've added another test for all the dtypes.

@sxjscience
Member Author

Should I merge it in?

@piiswrong
Contributor

rename to _backward_gather_nd

@sxjscience
Member Author

@piiswrong I've renamed accordingly.

@@ -510,6 +548,10 @@ The elements in output is defined as follows::

all other entries in output are 0.

WARNING!!! If the indices have duplicates, the result will be non-deterministic and
Contributor

This looks ugly. The standard warning format is
.. Warning:: xxx

@piiswrong piiswrong merged commit d918868 into apache:master Jan 4, 2018
sxjscience added a commit that referenced this pull request Jan 4, 2018
* try to implement scatter_nd_acc

fix

fix

fix

update

only support real_type

update

update

try to fix

update

fix

update

revise test

fix lint

* fix

* mark line as no lint

* fix test

* revise test

* fix test case

* revise

* remove openmp

* update

* update

* update

* update test

* Revert "update test"

This reverts commit 3eb3ac6.

* Revert "update"

This reverts commit a28fa53.

* Revert "update"

This reverts commit e99ffd0.

* Revert "update"

This reverts commit 399ba02.

* add atomic and specialize the behavior of half_t

* use "!" instead of not

* add test

* fix test

* fix test

* fix test

* rename to backward_gather_nd

* fix

* fix

* fix doc
yuxiangw pushed a commit to yuxiangw/incubator-mxnet that referenced this pull request Jan 25, 2018
rahul003 pushed a commit to rahul003/mxnet that referenced this pull request Jun 4, 2018
zheng-da pushed a commit to zheng-da/incubator-mxnet that referenced this pull request Jun 28, 2018