Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[Numpy] Fix collect_params().zero_grad() in gluon numpy interface #16716

Merged
merged 4 commits into from
Nov 13, 2019
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions python/mxnet/gluon/parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -903,8 +903,13 @@ def zero_grad(self):
if len(arrays) == 0:
return

for arr in arrays.values():
mx.nd.reset_arrays(*arr, num_arrays=len(arr))
if is_np_array():
for arr in arrays.values():
for ele in arr:
ele[()] = 0
else:
for arr in arrays.values():
mx.nd.reset_arrays(*arr, num_arrays=len(arr))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why not always use in-place assign?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I’m not sure why we used reset_arrays before. I guess that it would be faster if we use multiple arrays.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if that's the case, then we need its equivalence in npx namespace @reminisce

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add an alias _npi_reset_arrays in reset_arrays.cc?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've checked the source code. The new approach should be fine as long as we use cudaMemsetAsync for implementing ele[()] = 0. In fact, reset_arrays.cc lies in the contrib folder and there is no need to add it to numpy.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds good to me.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think this is worth having diverging implementation. If reset_arrays is not useful then we should stay away from it.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shall we move away from reset_array in the old ndarary too?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I actually do not know why we’ve used the reset_array. This op should be in the contrib while now it’s in the main API. I think this is somehow out-of-the-scope of this PR.


def reset_ctx(self, ctx):
"""Re-assign all Parameters to other contexts.
Expand Down
17 changes: 17 additions & 0 deletions tests/python/unittest/test_numpy_gluon.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,23 @@ def hybrid_forward(self, F, x, weight):
assert_almost_equal(out.asnumpy(), (x.asnumpy() + const_arr), atol=1e-5, rtol=1e-4, use_broadcast=False)


@use_np
def test_parameters_zero_grad():
for hybridize in [False, True]:
net = gluon.nn.HybridSequential()
for _ in range(5):
net.add(gluon.nn.Dense(10))
if hybridize:
net.hybridize()
net.initialize()
out = net(mx.np.ones((32, 8)))
for v in net.collect_params().values():
v.grad()[()] = 1
net.collect_params().zero_grad()
for v in net.collect_params().values():
assert_almost_equal(v.grad().asnumpy(), mx.np.zeros_like(v.grad()).asnumpy())


if __name__ == '__main__':
import nose
nose.runmodule()