Skip to content

Commit

Permalink
raise warning when we detect Block inside nested list/dict (apache#9148)
Browse files Browse the repository at this point in the history
* add warning for block in nested list dict

Block inside contiainers is not supported

* test ci again
  • Loading branch information
sxjscience authored and zheng-da committed Jun 28, 2018
1 parent 09495d8 commit db98d52
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 1 deletion.
29 changes: 29 additions & 0 deletions python/mxnet/gluon/block.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,33 @@ def __setattr__(self, name, value):

super(Block, self).__setattr__(name, value)

def _check_container_with_block(self):
def _find_block_in_container(data):
# Find whether a nested container structure contains Blocks
if isinstance(data, (list, tuple)):
for ele in data:
if _find_block_in_container(ele):
return True
return False
elif isinstance(data, dict):
for _, v in data.items():
if _find_block_in_container(v):
return True
return False
elif isinstance(data, Block):
return True
else:
return False
for k, v in self.__dict__.items():
if isinstance(v, (list, tuple, dict)) and not (k.startswith('__') or k == '_children'):
if _find_block_in_container(v):
warnings.warn('"{name}" is a container with Blocks. '
'Note that Blocks inside the list, tuple or dict will not be '
'registered automatically. Make sure to register them using '
'register_child() or switching to '
'nn.Sequential/nn.HybridSequential instead. '
.format(name=self.__class__.__name__ + "." + k))

def _alias(self):
return self.__class__.__name__.lower()

Expand Down Expand Up @@ -252,6 +279,8 @@ def collect_params(self, select=None):
-------
The selected :py:class:`ParameterDict`
"""
# We need to check here because blocks inside containers are not supported.
self._check_container_with_block()
ret = ParameterDict(self._params.prefix)
if not select:
ret.update(self.params)
Expand Down
2 changes: 1 addition & 1 deletion python/mxnet/gluon/nn/basic_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ class HybridSequential(HybridBlock):
Example::
net = nn.Sequential()
net = nn.HybridSequential()
# use net's name_scope to give child Blocks appropriate names.
with net.name_scope():
net.add(nn.Dense(10, activation='relu'))
Expand Down
44 changes: 44 additions & 0 deletions tests/python/unittest/test_gluon.py
Original file line number Diff line number Diff line change
Expand Up @@ -475,6 +475,50 @@ def test_block_attr_regular():
b.c = c2
assert b.c is c2 and b._children[0] is c2

def test_block_attr_list_of_block():
class Model1(gluon.Block):
def __init__(self, **kwargs):
super(Model1, self).__init__(**kwargs)
with self.name_scope():
self.layers = [nn.Dense(i * 10) for i in range(6)]

class Model2(gluon.Block):
def __init__(self, **kwargs):
super(Model2, self).__init__(**kwargs)
with self.name_scope():
self.layers = dict()
self.layers['a'] = [nn.Dense(10), nn.Dense(10)]

class Model3(gluon.Block):
def __init__(self, **kwargs):
super(Model3, self).__init__(**kwargs)
with self.name_scope():
self.layers = nn.Sequential()
self.layers.add(*[nn.Dense(i * 10) for i in range(6)])

class Model4(gluon.Block):
def __init__(self, **kwargs):
super(Model4, self).__init__(**kwargs)
with self.name_scope():
self.data = {'a': '4', 'b': 123}

with warnings.catch_warnings(record=True) as w:
model = Model1()
model.collect_params()
assert len(w) > 0
with warnings.catch_warnings(record=True) as w:
model = Model2()
model.collect_params()
assert len(w) > 0
with warnings.catch_warnings(record=True) as w:
model = Model3()
model.collect_params()
assert len(w) == 0
with warnings.catch_warnings(record=True) as w:
model = Model4()
model.collect_params()
assert len(w) == 0

def test_sequential_warning():
with warnings.catch_warnings(record=True) as w:
b = gluon.nn.Sequential()
Expand Down

0 comments on commit db98d52

Please sign in to comment.