From 0f76ef5cc236e6499ecaa66a126e2f40f70781ec Mon Sep 17 00:00:00 2001 From: Serge Panev Date: Tue, 17 Sep 2019 12:59:28 +0200 Subject: [PATCH] Add list_ctx to ParameterDict Signed-off-by: Serge Panev --- python/mxnet/gluon/parameter.py | 8 ++++++++ tests/python/unittest/test_gluon.py | 11 +++++++++++ 2 files changed, 19 insertions(+) diff --git a/python/mxnet/gluon/parameter.py b/python/mxnet/gluon/parameter.py index 89a3c338173c..66f5c2afc8b0 100644 --- a/python/mxnet/gluon/parameter.py +++ b/python/mxnet/gluon/parameter.py @@ -902,6 +902,14 @@ def reset_ctx(self, ctx): for i in self.values(): i.reset_ctx(ctx) + def list_ctx(self): + """Returns a list of all the contexts on which the underlying Parameters + are initialized.""" + s = set() + for i in self.values(): + s.update(i.list_ctx()) + return list(s) + def setattr(self, name, value): """Set an attribute to a new value for all Parameters. diff --git a/tests/python/unittest/test_gluon.py b/tests/python/unittest/test_gluon.py index af30980b10ea..0d0eaf4a3ee8 100644 --- a/tests/python/unittest/test_gluon.py +++ b/tests/python/unittest/test_gluon.py @@ -135,6 +135,17 @@ def test_parameter_dict(): mx.test_utils.assert_almost_equal(prev_w0.asnumpy(), cur_w0.asnumpy()) mx.test_utils.assert_almost_equal(prev_w1.asnumpy(), cur_w1.asnumpy()) + # test reset_ctx + params3 = gluon.ParameterDict('net_') + params3.get('w0', shape=(10, 10)) + params3.get('w1', shape=(10, 10)) + params3.initialize(ctx) + list_contexts = [mx.cpu(42), mx.cpu(24)] + params3.reset_ctx(list_contexts) + # and test list_ctx + assert set(params3.list_ctx()) == set(list_contexts) + + # test the dtype casting functionality params0 = gluon.ParameterDict('') params0.get('w0', shape=(10, 10), dtype='float32')