-
Notifications
You must be signed in to change notification settings - Fork 22
/
get_flops.py
executable file
·109 lines (89 loc) · 3.57 KB
/
get_flops.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import argparse
from mmcv import Config
from mmcv.cnn import get_model_complexity_info
from mmseg.models import build_segmentor
import models
from fvcore.nn import FlopCountAnalysis
import torch
from numbers import Number
from typing import Any, Callable, List, Optional, Union
from numpy import prod
import numpy as np
from fvcore.nn import FlopCountAnalysis
from mmseg.datasets import build_dataset
def _module_gflops(fca, module_name):
    """Return the GFLOPs fvcore attributes to ``module_name``, or None if unknown.

    Only ``KeyError`` is treated as "this component is not part of the model"
    (what recent fvcore versions raise for a name that is not a submodule of
    the traced model); any other error propagates instead of being hidden.
    """
    try:
        return fca.total(module_name=module_name) / 1e9
    except KeyError:
        return None


def calc_flops(model, img_size=224, device=None):
    """Count and print the FLOPs of ``model`` with fvcore.

    One random image is traced. GFLOPs are printed for ``backbone`` and
    ``decode_head`` (always), for ``text_encoder``, ``context_decoder`` and
    ``neck`` (each only if fvcore knows that module name), and for the whole
    model.

    Args:
        model: The module to analyse; ``model(x)`` is what gets traced.
        img_size: Input resolution, either a single int for a square image
            or a ``(height, width)`` pair.
        device: Device of the dummy input. Defaults to the device of the
            model's first parameter, or ``'cuda'`` if it has no parameters.

    Returns:
        float: Total GFLOPs of the model.

    Errors raised by fvcore while looking up ``backbone`` or ``decode_head``
    are not caught.
    """
    if isinstance(img_size, (tuple, list)):
        height, width = img_size
    else:
        height = width = img_size
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else 'cuda'

    with torch.no_grad():
        x = torch.randn(1, 3, height, width, device=device)
        fca = FlopCountAnalysis(model, x)
        print('backbone:', fca.total(module_name="backbone") / 1e9)
        # Optional components are looked up one by one, so the absence of
        # one of them never suppresses the report for another.
        for name in ('text_encoder', 'context_decoder', 'neck'):
            gflops = _module_gflops(fca, name)
            if gflops is not None:
                print(f'{name}:', gflops)
        print('decode_head:', fca.total(module_name="decode_head") / 1e9)
        total_gflops = fca.total() / 1e9
    print("#### GFLOPs: {:.1f}".format(total_gflops))
    return total_gflops


def parse_args(argv=None):
    """Parse the command line.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``.
    """
    parser = argparse.ArgumentParser(
        description='Count the FLOPs and parameters of a segmentor')
    parser.add_argument('config', help='config file path')
    parser.add_argument(
        '--fvcore',
        action='store_true',
        default=False,
        help='count FLOPs with fvcore (per-component breakdown) instead of '
        'mmcv.cnn.get_model_complexity_info')
    parser.add_argument(
        '--shape',
        type=int,
        nargs='+',
        default=[1024, 1024],
        help='input image size: one value for a square image, or height '
        'and width')
    return parser.parse_args(argv)


def _resolve_input_shape(shape):
    """Turn the ``--shape`` values into a ``(3, height, width)`` tuple.

    One value means a square image; two values are ``(height, width)``.

    Raises:
        ValueError: if ``shape`` has any other number of values.
    """
    if len(shape) == 1:
        return (3, shape[0], shape[0])
    if len(shape) == 2:
        return (3, ) + tuple(shape)
    raise ValueError('invalid input shape')


def _count_params_m(module):
    """Return the number of trainable parameters of ``module``, in millions."""
    return sum(p.numel() for p in module.parameters() if p.requires_grad) / 1e6


def main():
    """Build the segmentor described by the config and report FLOPs/params."""
    args = parse_args()
    input_shape = _resolve_input_shape(args.shape)

    cfg = Config.fromfile(args.config)
    cfg.model.pretrained = None
    if 'CLIP' in cfg.model.type:
        # Only CLIP-type models need the class names, so only they pay for
        # building the training dataset.
        dataset = build_dataset(cfg.data.train)
        cfg.model.class_names = list(dataset.CLASSES)
    model = build_segmentor(
        cfg.model,
        train_cfg=cfg.get('train_cfg'),
        test_cfg=cfg.get('test_cfg')).cuda()
    model.eval()

    if not hasattr(model, 'forward_dummy'):
        raise NotImplementedError(
            'FLOPs counter is currently not supported with {}'.format(
                model.__class__.__name__))
    # Both counters call model(x) with only the dummy image tensor; route
    # forward() to forward_dummy for that (presumably the image-only path).
    model.forward = model.forward_dummy

    if args.fvcore:
        # input_shape[1:] is (height, width), so non-square inputs are honoured.
        calc_flops(model, input_shape[1:])
        n_parameters = _count_params_m(model)
        print('number of params:', f'{n_parameters:.1f}')
        if hasattr(model, 'text_encoder'):
            print('param without text encoder:',
                  n_parameters - _count_params_m(model.text_encoder))
        if hasattr(model, 'context_decoder'):
            print('param context:', _count_params_m(model.context_decoder))
    else:
        flops, params = get_model_complexity_info(model, input_shape)
        split_line = '=' * 30
        print('{0}\nInput shape: {1}\nFlops: {2}\nParams: {3}\n{0}'.format(
            split_line, input_shape, flops, params))
        print('!!!Please be cautious if you use the results in papers. '
              'You may need to check if all ops are supported and verify that the '
              'flops computation is correct.')
# Script entry point: run only when executed directly, not on import.
if __name__ == "__main__":
    main()