diff --git a/mmpose/models/pose_estimators/base.py b/mmpose/models/pose_estimators/base.py index b97232b344..057f0cf9e6 100644 --- a/mmpose/models/pose_estimators/base.py +++ b/mmpose/models/pose_estimators/base.py @@ -90,7 +90,7 @@ def _load_metainfo(metainfo: dict = None) -> dict: def forward(self, inputs: torch.Tensor, - data_samples: OptSampleList, + data_samples: OptSampleList = None, mode: str = 'tensor') -> ForwardResults: """The unified entry for a forward process in both training and test. diff --git a/tools/analysis_tools/get_flops.py b/tools/analysis_tools/get_flops.py index b62a320909..3c2c5c47fc 100644 --- a/tools/analysis_tools/get_flops.py +++ b/tools/analysis_tools/get_flops.py @@ -1,83 +1,85 @@ # Copyright (c) OpenMMLab. All rights reserved. import argparse -import torch +from mmengine.config import Config, DictAction + +from mmpose.registry import MODELS +from mmpose.utils import register_all_modules try: - from fvcore.nn import (ActivationCountAnalysis, FlopCountAnalysis, - flop_count_str, flop_count_table, parameter_count) + from mmengine.analysis import get_model_complexity_info except ImportError: - print('You may need to install fvcore for flops computation, ' - 'and you can use `pip install fvcore` to set up the environment') -from fvcore.nn.print_model_statistics import _format_size -from mmengine import Config - -from mmpose.models import build_pose_estimator -from mmpose.utils import register_all_modules + raise ImportError('Please upgrade mmcv to >0.6.2') def parse_args(): - parser = argparse.ArgumentParser(description='Get model flops and params') - parser.add_argument('config', help='config file path') + parser = argparse.ArgumentParser(description='Train a detector') + parser.add_argument('config', help='train config file path') parser.add_argument( '--shape', type=int, nargs='+', - default=[256, 192], + default=[1280, 800], help='input image size') + parser.add_argument( + '--cfg-options', + nargs='+', + action=DictAction, + help='override some settings in the used config, the key-value pair ' + 'in xxx=yyy format will be merged into config file. If the value to ' + 'be overwritten is a list, it should be like key="[a,b]" or key=a,b ' + 'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" ' + 'Note that the quotation marks are necessary and that no white space ' + 'is allowed.') args = parser.parse_args() return args def main(): - + register_all_modules() args = parse_args() if len(args.shape) == 1: - input_shape = (3, args.shape[0], args.shape[0]) + h = w = args.shape[0] elif len(args.shape) == 2: - input_shape = (3, ) + tuple(args.shape) + h, w = args.shape else: raise ValueError('invalid input shape') + input_shape = (3, h, w) cfg = Config.fromfile(args.config) - model = build_pose_estimator(cfg.model) - model.eval() - - if hasattr(model, 'extract_feat'): - model.forward = model.extract_feat - else: - raise NotImplementedError( - 'FLOPs counter is currently not currently supported with {}'. - format(model.__class__.__name__)) - - inputs = (torch.randn((1, *input_shape)), ) - flops_ = FlopCountAnalysis(model, inputs) - activations_ = ActivationCountAnalysis(model, inputs) + if args.cfg_options is not None: + cfg.merge_from_dict(args.cfg_options) - flops = _format_size(flops_.total()) - activations = _format_size(activations_.total()) - params = _format_size(parameter_count(model)['']) + model = MODELS.build(cfg.model) + model.eval() - flop_table = flop_count_table( - flops=flops_, - activations=activations_, - show_param_shapes=True, - ) - flop_str = flop_count_str(flops=flops_, activations=activations_) + analysis_results = get_model_complexity_info( + model, input_shape, show_table=True, show_arch=False) - print('\n' + flop_str) - print('\n' + flop_table) + # ayalysis_results = { + # 'flops': flops, + # 'flops_str': flops_str, + # 'activations': activations, + # 'activations_str': activations_str, + # 'params': params, + # 'params_str': params_str, + # 'out_table': complexity_table, + # 'out_arch': complexity_arch + # } split_line = '=' * 30 print(f'{split_line}\nInput shape: {input_shape}\n' - f'Flops: {flops}\nParams: {params}\n' - f'Activation: {activations}\n{split_line}') + f'Flops: {analysis_results["flops"]}\n' + f'Params: {analysis_results["params"]}\n{split_line}') + + print(analysis_results['activations']) + # print(analysis_results['complexity_table']) + # print(complexity_str) print('!!!Please be cautious if you use the results in papers. ' 'You may need to check if all ops are supported and verify that the ' 'flops computation is correct.') if __name__ == '__main__': - register_all_modules() main() diff --git a/tools/analysis_tools/get_flops1.py b/tools/analysis_tools/get_flops1.py new file mode 100644 index 0000000000..ab603e1e3b --- /dev/null +++ b/tools/analysis_tools/get_flops1.py @@ -0,0 +1,106 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import argparse + +import torch +from mmengine.config import DictAction + +from mmpose.apis.inference import init_model + +try: + # from mmcv.cnn import get_model_complexity_info + from mmengine.analysis import get_model_complexity_info +except ImportError: + raise ImportError('Please upgrade mmcv to >0.6.2') + + +def parse_args(): + parser = argparse.ArgumentParser(description='Train a recognizer') + parser.add_argument('config', help='train config file path') + parser.add_argument( + '--device', default='cpu', help='Device used for model initialization') + parser.add_argument( + '--cfg-options', + nargs='+', + action=DictAction, + default={}, + help='override some settings in the used config, the key-value pair ' + 'in xxx=yyy format will be merged into config file. For example, ' + "'--cfg-options model.backbone.depth=18 model.backbone.with_cp=True'") + parser.add_argument( + '--shape', + type=int, + nargs='+', + default=[256, 192], + help='input image size') + parser.add_argument( + '--input-constructor', + '-c', + type=str, + choices=['none', 'batch'], + default='none', + help='If specified, it takes a callable method that generates ' + 'input. Otherwise, it will generate a random tensor with ' + 'input shape to calculate FLOPs.') + parser.add_argument( + '--batch-size', '-b', type=int, default=1, help='input batch size') + parser.add_argument( + '--not-print-per-layer-stat', + '-n', + action='store_true', + help='Whether to print complexity information' + 'for each layer in a model') + args = parser.parse_args() + return args + + +def batch_constructor(flops_model, batch_size, input_shape): + """Generate a batch of tensors to the model.""" + batch = {} + + inputs = torch.ones(()).new_empty( + (batch_size, *input_shape), + dtype=next(flops_model.parameters()).dtype, + device=next(flops_model.parameters()).device) + + batch['inputs'] = inputs + return batch + + +def main(): + + args = parse_args() + + if len(args.shape) == 1: + input_shape = (3, args.shape[0], args.shape[0]) + elif len(args.shape) == 2: + input_shape = (3, ) + tuple(args.shape) + else: + raise ValueError('invalid input shape') + + model = init_model( + args.config, + checkpoint=None, + device=args.device, + cfg_options=args.cfg_options) + + if hasattr(model, '_forward'): + model.forward = model._forward + else: + raise NotImplementedError( + 'FLOPs counter is currently not currently supported with {}'. + format(model.__class__.__name__)) + + analysis_results = get_model_complexity_info(model, input_shape) + flops = analysis_results['flops_str'] + params = analysis_results['params_str'] + split_line = '=' * 30 + input_shape = (args.batch_size, ) + input_shape + print(f'{split_line}\nInput shape: {input_shape}\n' + f'Flops: {flops}\nParams: {params}\n{split_line}') + print('!!!Please be cautious if you use the results in papers. ' + 'You may need to check if all ops are supported and verify that the ' + 'flops computation is correct.') + + +if __name__ == '__main__': + main()