Skip to content

Commit

Permalink
use mmengine
Browse files Browse the repository at this point in the history
  • Loading branch information
ly015 committed Mar 13, 2023
1 parent 30154f0 commit 3b0e12d
Show file tree
Hide file tree
Showing 3 changed files with 152 additions and 44 deletions.
2 changes: 1 addition & 1 deletion mmpose/models/pose_estimators/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def _load_metainfo(metainfo: dict = None) -> dict:

def forward(self,
inputs: torch.Tensor,
data_samples: OptSampleList,
data_samples: OptSampleList = None,
mode: str = 'tensor') -> ForwardResults:
"""The unified entry for a forward process in both training and test.
Expand Down
88 changes: 45 additions & 43 deletions tools/analysis_tools/get_flops.py
Original file line number Diff line number Diff line change
@@ -1,83 +1,85 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse

import torch
from mmengine.config import Config, DictAction

from mmpose.registry import MODELS
from mmpose.utils import register_all_modules

try:
from fvcore.nn import (ActivationCountAnalysis, FlopCountAnalysis,
flop_count_str, flop_count_table, parameter_count)
from mmengine.analysis import get_model_complexity_info
except ImportError:
print('You may need to install fvcore for flops computation, '
'and you can use `pip install fvcore` to set up the environment')
from fvcore.nn.print_model_statistics import _format_size
from mmengine import Config

from mmpose.models import build_pose_estimator
from mmpose.utils import register_all_modules
raise ImportError('Please upgrade mmcv to >0.6.2')


def parse_args():
parser = argparse.ArgumentParser(description='Get model flops and params')
parser.add_argument('config', help='config file path')
parser = argparse.ArgumentParser(description='Train a detector')
parser.add_argument('config', help='train config file path')
parser.add_argument(
'--shape',
type=int,
nargs='+',
default=[256, 192],
default=[1280, 800],
help='input image size')
parser.add_argument(
'--cfg-options',
nargs='+',
action=DictAction,
help='override some settings in the used config, the key-value pair '
'in xxx=yyy format will be merged into config file. If the value to '
'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
'Note that the quotation marks are necessary and that no white space '
'is allowed.')
args = parser.parse_args()
return args


def main():

register_all_modules()
args = parse_args()

if len(args.shape) == 1:
input_shape = (3, args.shape[0], args.shape[0])
h = w = args.shape[0]
elif len(args.shape) == 2:
input_shape = (3, ) + tuple(args.shape)
h, w = args.shape
else:
raise ValueError('invalid input shape')
input_shape = (3, h, w)

cfg = Config.fromfile(args.config)
model = build_pose_estimator(cfg.model)
model.eval()

if hasattr(model, 'extract_feat'):
model.forward = model.extract_feat
else:
raise NotImplementedError(
'FLOPs counter is currently not currently supported with {}'.
format(model.__class__.__name__))

inputs = (torch.randn((1, *input_shape)), )
flops_ = FlopCountAnalysis(model, inputs)
activations_ = ActivationCountAnalysis(model, inputs)
if args.cfg_options is not None:
cfg.merge_from_dict(args.cfg_options)

flops = _format_size(flops_.total())
activations = _format_size(activations_.total())
params = _format_size(parameter_count(model)[''])
model = MODELS.build(cfg.model)
model.eval()

flop_table = flop_count_table(
flops=flops_,
activations=activations_,
show_param_shapes=True,
)
flop_str = flop_count_str(flops=flops_, activations=activations_)
analysis_results = get_model_complexity_info(
model, input_shape, show_table=True, show_arch=False)

print('\n' + flop_str)
print('\n' + flop_table)
# ayalysis_results = {
# 'flops': flops,
# 'flops_str': flops_str,
# 'activations': activations,
# 'activations_str': activations_str,
# 'params': params,
# 'params_str': params_str,
# 'out_table': complexity_table,
# 'out_arch': complexity_arch
# }

split_line = '=' * 30
print(f'{split_line}\nInput shape: {input_shape}\n'
f'Flops: {flops}\nParams: {params}\n'
f'Activation: {activations}\n{split_line}')
f'Flops: {analysis_results["flops"]}\n'
f'Params: {analysis_results["params"]}\n{split_line}')

print(analysis_results['activations'])
# print(analysis_results['complexity_table'])
# print(complexity_str)
print('!!!Please be cautious if you use the results in papers. '
'You may need to check if all ops are supported and verify that the '
'flops computation is correct.')


if __name__ == '__main__':
register_all_modules()
main()
106 changes: 106 additions & 0 deletions tools/analysis_tools/get_flops1.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse

import torch
from mmengine.config import DictAction

from mmpose.apis.inference import init_model

try:
# from mmcv.cnn import get_model_complexity_info
from mmengine.analysis import get_model_complexity_info
except ImportError:
raise ImportError('Please upgrade mmcv to >0.6.2')


def parse_args():
parser = argparse.ArgumentParser(description='Train a recognizer')
parser.add_argument('config', help='train config file path')
parser.add_argument(
'--device', default='cpu', help='Device used for model initialization')
parser.add_argument(
'--cfg-options',
nargs='+',
action=DictAction,
default={},
help='override some settings in the used config, the key-value pair '
'in xxx=yyy format will be merged into config file. For example, '
"'--cfg-options model.backbone.depth=18 model.backbone.with_cp=True'")
parser.add_argument(
'--shape',
type=int,
nargs='+',
default=[256, 192],
help='input image size')
parser.add_argument(
'--input-constructor',
'-c',
type=str,
choices=['none', 'batch'],
default='none',
help='If specified, it takes a callable method that generates '
'input. Otherwise, it will generate a random tensor with '
'input shape to calculate FLOPs.')
parser.add_argument(
'--batch-size', '-b', type=int, default=1, help='input batch size')
parser.add_argument(
'--not-print-per-layer-stat',
'-n',
action='store_true',
help='Whether to print complexity information'
'for each layer in a model')
args = parser.parse_args()
return args


def batch_constructor(flops_model, batch_size, input_shape):
"""Generate a batch of tensors to the model."""
batch = {}

inputs = torch.ones(()).new_empty(
(batch_size, *input_shape),
dtype=next(flops_model.parameters()).dtype,
device=next(flops_model.parameters()).device)

batch['inputs'] = inputs
return batch


def main():

args = parse_args()

if len(args.shape) == 1:
input_shape = (3, args.shape[0], args.shape[0])
elif len(args.shape) == 2:
input_shape = (3, ) + tuple(args.shape)
else:
raise ValueError('invalid input shape')

model = init_model(
args.config,
checkpoint=None,
device=args.device,
cfg_options=args.cfg_options)

if hasattr(model, '_forward'):
model.forward = model._forward
else:
raise NotImplementedError(
'FLOPs counter is currently not currently supported with {}'.
format(model.__class__.__name__))

analysis_results = get_model_complexity_info(model, input_shape)
flops = analysis_results['flops_str']
params = analysis_results['params_str']
split_line = '=' * 30
input_shape = (args.batch_size, ) + input_shape
print(f'{split_line}\nInput shape: {input_shape}\n'
f'Flops: {flops}\nParams: {params}\n{split_line}')
print('!!!Please be cautious if you use the results in papers. '
'You may need to check if all ops are supported and verify that the '
'flops computation is correct.')


if __name__ == '__main__':
main()

0 comments on commit 3b0e12d

Please sign in to comment.