Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
add gluon model summary (#10989)
Browse files Browse the repository at this point in the history
* add hook api

* add block.summary

* remove count
  • Loading branch information
szha authored and piiswrong committed May 22, 2018
1 parent 0185f87 commit 022f238
Show file tree
Hide file tree
Showing 4 changed files with 287 additions and 3 deletions.
172 changes: 169 additions & 3 deletions python/mxnet/gluon/block.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
from ..ndarray import NDArray
from .. import name as _name
from .parameter import Parameter, ParameterDict, DeferredInitializationError
from .utils import _indent, _brief_print_list
from .utils import _indent, _brief_print_list, HookHandle


class _BlockScope(object):
Expand Down Expand Up @@ -173,6 +173,8 @@ def __init__(self, prefix=None, params=None):
self._scope = _BlockScope(self)
self._children = OrderedDict()
self._reg_params = {}
self._forward_hooks = OrderedDict()
self._forward_pre_hooks = OrderedDict()

def __repr__(self):
s = '{name}(\n{modstr}\n)'
Expand Down Expand Up @@ -355,14 +357,68 @@ def load_params(self, filename, ctx=None, allow_missing=False,
name, filename, _brief_print_list(self._params.keys())))
params[name]._load_init(loaded[name], ctx)


def register_child(self, block, name=None):
"""Registers block as a child of self. :py:class:`Block` s assigned to self as
attributes will be registered automatically."""
if name is None:
name = str(len(self._children))
self._children[name] = block

def register_forward_pre_hook(self, hook):
r"""Registers a forward pre-hook on the block.
The hook function is called immediately before :func:`forward`.
It should not modify the input or output.
Parameters
----------
hook : callable
The forward hook function of form `hook(block, input) -> None`.
Returns
-------
:class:`mxnet.gluon.utils.HookHandle`
"""
handle = HookHandle()
handle.attach(self._forward_pre_hooks, hook)
return handle

def register_forward_hook(self, hook):
r"""Registers a forward hook on the block.
The hook function is called immediately after :func:`forward`.
It should not modify the input or output.
Parameters
----------
hook : callable
The forward hook function of form `hook(block, input, output) -> None`.
Returns
-------
:class:`mxnet.gluon.utils.HookHandle`
"""
handle = HookHandle()
handle.attach(self._forward_hooks, hook)
return handle

def apply(self, fn):
r"""Applies ``fn`` recursively to every child block as well as self.
Parameters
----------
fn : callable
Function to be applied to each submodule, of form `fn(block)`.
Returns
-------
this block
"""
for cld in self._children.values():
cld.apply(fn)
fn(self)
return self

def initialize(self, init=initializer.Uniform(), ctx=None, verbose=False,
force_reinit=False):
"""Initializes :py:class:`Parameter` s of this :py:class:`Block` and its children.
Expand Down Expand Up @@ -411,7 +467,15 @@ def cast(self, dtype):

def __call__(self, *args):
"""Calls forward. Only accepts positional arguments."""
return self.forward(*args)
for hook in self._forward_pre_hooks.values():
hook(self, args)

out = self.forward(*args)

for hook in self._forward_hooks.values():
hook(self, args, out)

return out

def forward(self, *args):
"""Overrides to implement forward computation using :py:class:`NDArray`. Only
Expand All @@ -425,6 +489,105 @@ def forward(self, *args):
# pylint: disable= invalid-name
raise NotImplementedError

def summary(self, *inputs):
"""Print the summary of the model's output and parameters.
The network must have been initialized, and must not have been hybridized.
Parameters
----------
inputs : object
Any input that the model supports. For any tensor in the input, only
:class:`mxnet.ndarray.NDArray` is supported.
"""
summary = OrderedDict()
hooks = []

def _get_shape_str(args):
def flatten(args):
if not isinstance(args, (list, tuple)):
return [args], int(0)
flat = []
fmts = []
for i in args:
arg, fmt = flatten(i)
flat.extend(arg)
fmts.append(fmt)
return flat, fmts

def regroup(args, fmt):
if isinstance(fmt, int):
if fmt == 0:
return args[0], args[1:]
return args[:fmt], args[fmt:]
ret = []
for i in fmt:
res, args = regroup(args, i)
ret.append(res)
return ret, args

flat_args, fmts = flatten(args)
flat_arg_shapes = [x.shape if isinstance(x, ndarray.NDArray) else x
for x in flat_args]
shapes = regroup(flat_arg_shapes, fmts)[0]
if isinstance(shapes, list):
shape_str = str(shapes)[1:-1]
else:
shape_str = str(shapes)
return shape_str.replace('L', '')

def _register_summary_hook(block):
assert not isinstance(block, HybridBlock) or not block._active, \
'"{}" must not be hybridized to print summary.'.format(block.name)
def _summary_hook(block, _, outputs):
class_name = block.__class__.__name__
block_idx = len(summary) - 1

m_key = '%s-%i' % (class_name, block_idx+1)
summary[m_key] = OrderedDict()
summary[m_key]['output_shape'] = _get_shape_str(outputs)

params = 0
summary[m_key]['trainable'] = 0
for p in block._reg_params.values():
params += p.data().size
summary[m_key]['trainable'] += 0 if p.grad_req == 'null' else p.data().size
summary[m_key]['n_params'] = params

from .nn.basic_layers import Sequential, HybridSequential
if not isinstance(block, (Sequential, HybridSequential)):
hooks.append(block.register_forward_hook(_summary_hook))

summary['Input'] = OrderedDict()
summary['Input']['output_shape'] = _get_shape_str(inputs)
summary['Input']['n_params'] = 0
summary['Input']['trainable'] = 0

try:
self.apply(_register_summary_hook)
self(*inputs)

line_format = '{:>20} {:>42} {:>15}'
print('-'*80)
print(line_format.format('Layer (type)', 'Output Shape', 'Param #'))
print('='*80)
total_params = 0
trainable_params = 0
for layer in summary:
print(line_format.format(layer,
str(summary[layer]['output_shape']),
summary[layer]['n_params']))
total_params += summary[layer]['n_params']
trainable_params += summary[layer]['trainable']
print('='*80)
print('Total params: ' + str(total_params))
print('Trainable params: ' + str(trainable_params))
print('Non-trainable params: ' + str(total_params - trainable_params))
print('-'*80)
finally:
for h in hooks:
h.detach()


class HybridBlock(Block):
"""`HybridBlock` supports forwarding with both Symbol and NDArray.
Expand Down Expand Up @@ -549,6 +712,9 @@ def hybridize(self, active=True, **kwargs):
self._active = active
self._flags = kwargs.items()
self._clear_cached_op()
if active and self._forward_hooks or self._forward_pre_hooks:
warnings.warn('"{}" is being hybridized while still having forward hook/pre-hook. '
'If "{}" is a child of HybridBlock, the hooks will not take effect.')
super(HybridBlock, self).hybridize(active, **kwargs)

def cast(self, dtype):
Expand Down
4 changes: 4 additions & 0 deletions python/mxnet/gluon/rnn/rnn_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,10 @@ def __init__(self, hidden_size, num_layers, layout,
allow_deferred_init=True))
ni = nh * self._dir

for param_list in [self.i2h_weight, self.h2h_weight, self.i2h_bias, self.h2h_bias]:
for p in param_list:
self._reg_params[p.name] = p

self._unfused = self._unfuse()

def __repr__(self):
Expand Down
37 changes: 37 additions & 0 deletions python/mxnet/gluon/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
import os
import hashlib
import warnings
import collections
import weakref
try:
import requests
except ImportError:
Expand Down Expand Up @@ -250,3 +252,38 @@ def _brief_print_list(lst, limit=7):
return _brief_print_list(lst[:limit//2], limit) + ', ..., ' + \
_brief_print_list(lst[-limit//2:], limit)
return ', '.join(["'%s'"%str(i) for i in lst])


class HookHandle(object):
"""A handle that can attach/detach a hook."""

def __init__(self):
self._hooks_dict_ref = None
self._id = None

def attach(self, hooks_dict, hook):
assert not self._hooks_dict_ref, 'The same handle cannot be attached twice.'
self._id = id(hook)
hooks_dict[self._id] = hook
self._hooks_dict_ref = weakref.ref(hooks_dict)

def detach(self):
hooks_dict = self._hooks_dict_ref()
if hooks_dict is not None and self._id in hooks_dict:
del hooks_dict[self._id]

def __getstate__(self):
return (self._hooks_dict_ref(), self._id)

def __setstate__(self, state):
if state[0] is None:
self._hooks_dict_ref = weakref.ref(collections.OrderedDict())
else:
self._hooks_dict_ref = weakref.ref(state[0])
self._id = state[1]

def __enter__(self):
return self

def __exit__(self, ptype, value, trace):
self.detach()
77 changes: 77 additions & 0 deletions tests/python/unittest/test_gluon.py
Original file line number Diff line number Diff line change
Expand Up @@ -1069,6 +1069,83 @@ def test_zero_grad():
assert_almost_equal(grad.asnumpy(), grad.asnumpy() * 0)


@with_seed()
def test_hook():
global hook_call_count
hook_call_count = 0
global pre_hook_call_count
pre_hook_call_count = 0

def call_hook(block, x, y):
global hook_call_count
hook_call_count += 1

def call_pre_hook(block, x):
global pre_hook_call_count
pre_hook_call_count += 1

block = nn.Dense(10)
block.initialize()
handle = block.register_forward_hook(call_hook)
pre_handle = block.register_forward_pre_hook(call_pre_hook)
block(mx.nd.ones((3, 5)))

assert hook_call_count == 1
assert pre_hook_call_count == 1

handle.detach()
block(mx.nd.ones((3, 5)))

assert hook_call_count == 1
assert pre_hook_call_count == 2

pre_handle.detach()
block(mx.nd.ones((3, 5)))
assert hook_call_count == 1
assert pre_hook_call_count == 2


@with_seed()
def test_apply():
global called_blocks
called_blocks = []

def record_name(block):
global called_blocks
called_blocks.append(block.name)

block = nn.HybridSequential(prefix='test_')
with block.name_scope():
block.add(nn.Dense(10))
block.add(nn.Dropout(0.5))
block.apply(record_name)

assert called_blocks == ['test_dense0', 'test_dropout0', 'test']


@with_seed()
def test_summary():
net = gluon.model_zoo.vision.resnet50_v1()
net.initialize()
net.summary(mx.nd.ones((32, 3, 224, 224)))

net2 = nn.Sequential()
with net2.name_scope():
net2.add(nn.Embedding(10, 20))
net2.add(gluon.rnn.LSTM(30))
net2.add(nn.Dense(40, flatten=False))
net2.initialize()
net2.summary(mx.nd.ones((80, 32)))

net3 = gluon.rnn.LSTM(30)
net3.initialize()
begin_state = net3.begin_state(32)
net3.summary(mx.nd.ones((80, 32, 5)), begin_state)

net.hybridize()
assert_raises(AssertionError, net.summary, mx.nd.ones((32, 3, 224, 224)))


if __name__ == '__main__':
import nose
nose.runmodule()

0 comments on commit 022f238

Please sign in to comment.