From f4cc5338a1ddfcc785ef5058eadb7719f4ed4eea Mon Sep 17 00:00:00 2001 From: root Date: Fri, 12 Feb 2021 23:29:36 +0000 Subject: [PATCH] Use multi-tensor zeroing for resetting grads --- python/mxnet/gluon/block.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/python/mxnet/gluon/block.py b/python/mxnet/gluon/block.py index 547fbaa8a6c2..299df1843b53 100644 --- a/python/mxnet/gluon/block.py +++ b/python/mxnet/gluon/block.py @@ -742,18 +742,16 @@ def zero_grad(self): if g.stype == 'row_sparse': ndarray.zeros_like(g, out=g) else: - arrays[g.ctx].append(g) + if is_np_array(): + arrays[g.ctx].append(g.as_nd_ndarray()) + else: + arrays[g.ctx].append(g) if len(arrays) == 0: return - if is_np_array(): - for arr in arrays.values(): - for ele in arr: - ele[()] = 0 - else: - for arr in arrays.values(): - ndarray.reset_arrays(*arr, num_arrays=len(arr)) + for arr in arrays.values(): + ndarray.reset_arrays(*arr, num_arrays=len(arr)) def reset_ctx(self, ctx): """Re-assign all Parameters to other contexts.