From 2a47afb5f22a1b1c4c17baf660e235c6e3105b49 Mon Sep 17 00:00:00 2001 From: Hao Jin Date: Tue, 15 Oct 2019 21:05:11 +0000 Subject: [PATCH] fix optimizer bug --- python/mxnet/optimizer/optimizer.py | 2 +- tests/python/unittest/test_numpy_gluon.py | 9 +++++++++ 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/python/mxnet/optimizer/optimizer.py b/python/mxnet/optimizer/optimizer.py index c3a1f3374a94..bc87777e40fc 100644 --- a/python/mxnet/optimizer/optimizer.py +++ b/python/mxnet/optimizer/optimizer.py @@ -1942,7 +1942,7 @@ def __init__(self, optimizer): def __call__(self, index, grad, weight): """Updates weight given gradient and index.""" - allow_np = self.optimizer.allow_np_array + allow_np = self.optimizer.allow_np_array if hasattr(self.optimizer, "allow_np_array") else is_np_array() if not isinstance(index, (list, tuple)): indices = [index] grads = [_as_classic(grad, allow_np)] diff --git a/tests/python/unittest/test_numpy_gluon.py b/tests/python/unittest/test_numpy_gluon.py index f24eb6a325bf..af5425336699 100644 --- a/tests/python/unittest/test_numpy_gluon.py +++ b/tests/python/unittest/test_numpy_gluon.py @@ -113,6 +113,15 @@ def hybrid_forward(self, F, pred, label): trainer.step(1) +@with_seed() +@use_np +def test_optimizer_backward_compat(): + optimizer = mx.optimizer.SGD() + delattr(optimizer, "allow_np_array") + updater = mx.optimizer.Updater(optimizer) + updater(0, np.ones((0, 0)), np.zeros((0, 0))) + + @with_seed() @use_np def test_np_loss_ndarray():