From b843c3014f527ed19a0c769db75db780d97a27fe Mon Sep 17 00:00:00 2001
From: Zhijian Liu
Date: Thu, 15 Apr 2021 14:57:35 -0400
Subject: [PATCH] Add `upsample_bilinear2d`, unify norms, and bump version to 0.0.3

---
 setup.py                 |  1 +
 test.py                  | 22 ----------------------
 torchprofile/handlers.py | 32 ++++++++++++++++++++------------
 torchprofile/version.py  |  2 +-
 4 files changed, 22 insertions(+), 35 deletions(-)
 delete mode 100644 test.py

diff --git a/setup.py b/setup.py
index e760ddc..849c298 100644
--- a/setup.py
+++ b/setup.py
@@ -1,4 +1,5 @@
 from setuptools import find_packages, setup
+
 from torchprofile import __version__
 
 setup(
diff --git a/test.py b/test.py
deleted file mode 100644
index 730f77b..0000000
--- a/test.py
+++ /dev/null
@@ -1,22 +0,0 @@
-import torch
-import torch.nn as nn
-from torchprofile import profile_macs
-from torchprofile.utils.trace import trace
-
-
-class Model(nn.Module):
-    def forward(self, a, b):
-        return torch.matmul(a, b)
-
-
-if __name__ == '__main__':
-    a = torch.zeros(10, 20, 1, 20, 20)
-    b = torch.zeros(20, 30)
-
-    rnn = nn.LSTM(10, 20, 2)
-    input = torch.randn(5, 3, 10)
-    h0 = torch.randn(2, 3, 20)
-    c0 = torch.randn(2, 3, 20)
-    output, (hn, cn) = rnn(input, (h0, c0))
-    print(trace(rnn, (input, (h0, c0))))
-    print(profile_macs(rnn, (input, (h0, c0))))
diff --git a/torchprofile/handlers.py b/torchprofile/handlers.py
index 7c1bd0a..e0a4837 100644
--- a/torchprofile/handlers.py
+++ b/torchprofile/handlers.py
@@ -71,14 +71,16 @@ def convolution(node):
     return math.prod(os) * ic * math.prod(ks)
 
 
-def batch_norm(node):
-    # TODO: provide an option to not fuse `batch_norm` into `linear` or `conv`
-    return 0
-
+def norm(node):
+    if node.operator in ['aten::batch_norm', 'aten::instance_norm']:
+        affine = node.inputs[1].shape is not None
+    elif node.operator in ['aten::layer_norm', 'aten::group_norm']:
+        affine = node.inputs[2].shape is not None
+    else:
+        raise ValueError(node.operator)
 
-def instance_norm_or_layer_norm(node):
     os = node.outputs[0].shape
-    return math.prod(os)
+    return math.prod(os) if affine else 0
 
 
 def avg_pool_or_mean(node):
@@ -91,6 +93,11 @@ def leaky_relu(node):
     return math.prod(os)
 
 
+def upsample_bilinear2d(node):
+    os = node.outputs[0].shape
+    return math.prod(os) * 4
+
+
 handlers = (
     ('aten::addmm', addmm),
     ('aten::addmv', addmv),
@@ -98,22 +105,23 @@ def leaky_relu(node):
     ('aten::matmul', matmul),
     (('aten::mul', 'aten::mul_'), mul),
     ('aten::_convolution', convolution),
-    ('aten::batch_norm', batch_norm),
-    (('aten::instance_norm', 'aten::layer_norm'), instance_norm_or_layer_norm),
+    (('aten::batch_norm', 'aten::instance_norm', 'aten::layer_norm',
+      'aten::group_norm'), norm),
     (('aten::adaptive_avg_pool1d', 'aten::adaptive_avg_pool2d',
       'aten::adaptive_avg_pool3d', 'aten::avg_pool1d', 'aten::avg_pool2d',
       'aten::avg_pool3d', 'aten::mean'), avg_pool_or_mean),
     ('aten::leaky_relu', leaky_relu),
+    ('aten::upsample_bilinear2d', upsample_bilinear2d),
     (('aten::adaptive_max_pool1d', 'aten::adaptive_max_pool2d',
       'aten::adaptive_max_pool3d', 'aten::add', 'aten::add_',
       'aten::alpha_dropout', 'aten::cat', 'aten::chunk', 'aten::clamp',
       'aten::clone', 'aten::constant_pad_nd', 'aten::contiguous',
       'aten::detach', 'aten::div', 'aten::div_', 'aten::dropout',
       'aten::dropout_', 'aten::embedding', 'aten::eq', 'aten::feature_dropout',
-      'aten::flatten', 'aten::floor', 'aten::gt', 'aten::hardtanh_',
-      'aten::index', 'aten::int', 'aten::log_softmax', 'aten::lt',
-      'aten::max_pool1d', 'aten::max_pool1d_with_indices', 'aten::max_pool2d',
-      'aten::max_pool2d_with_indices', 'aten::max_pool3d',
+      'aten::flatten', 'aten::floor', 'aten::floor_divide', 'aten::gt',
+      'aten::hardtanh_', 'aten::index', 'aten::int', 'aten::log_softmax',
+      'aten::lt', 'aten::max_pool1d', 'aten::max_pool1d_with_indices',
+      'aten::max_pool2d', 'aten::max_pool2d_with_indices', 'aten::max_pool3d',
       'aten::max_pool3d_with_indices', 'aten::max_unpool1d',
       'aten::max_unpool2d', 'aten::max_unpool3d', 'aten::ne',
       'aten::reflection_pad1d', 'aten::reflection_pad2d',
diff --git a/torchprofile/version.py b/torchprofile/version.py
index d18f409..ffcc925 100644
--- a/torchprofile/version.py
+++ b/torchprofile/version.py
@@ -1 +1 @@
-__version__ = '0.0.2'
+__version__ = '0.0.3'
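
Usage sketch (illustrative only, not part of the patch above): the snippet below exercises the two behaviors this commit adds, counting bilinear upsampling at 4 MACs per output element and an affine normalization at 1 MAC per output element. It assumes tracing lowers nn.Upsample(mode='bilinear') to aten::upsample_bilinear2d and nn.GroupNorm to aten::group_norm; the model, shapes, and expected totals are my own example, not taken from the repository.

import torch
import torch.nn as nn

from torchprofile import profile_macs


class TinyModel(nn.Module):
    """Toy model that hits the handlers touched by this patch."""

    def __init__(self):
        super().__init__()
        # Affine GroupNorm: the unified `norm` handler counts one MAC per
        # output element when the weight tensor is present.
        self.norm = nn.GroupNorm(num_groups=2, num_channels=8)
        # Bilinear upsampling: the new `upsample_bilinear2d` handler counts
        # four MACs per output element.
        self.up = nn.Upsample(scale_factor=2, mode='bilinear',
                              align_corners=False)

    def forward(self, x):
        return self.up(self.norm(x))


if __name__ == '__main__':
    x = torch.randn(1, 8, 16, 16)
    # GroupNorm output: 1 * 8 * 16 * 16 = 2048 elements -> 2048 MACs
    # Upsample output:  1 * 8 * 32 * 32 = 8192 elements -> 8192 * 4 = 32768 MACs
    print(profile_macs(TinyModel(), x))  # 2048 + 32768 = 34816 under these handlers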