-
Notifications
You must be signed in to change notification settings - Fork 2
/
get_flops.py
181 lines (156 loc) · 6.29 KB
/
get_flops.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import tempfile
from functools import partial
from pathlib import Path
import torch
from mmengine.config import Config, DictAction
from mmengine.logging import MMLogger
from mmengine.model import revert_sync_batchnorm
from mmengine.registry import init_default_scope
from mmengine.runner import Runner
from mmdet.registry import MODELS
try:
from mmengine.analysis import get_model_complexity_info
from mmengine.analysis.print_helper import _format_size
except ImportError:
raise ImportError("Please upgrade mmengine >= 0.6.0")
from models import *
from torchprofile import profile_macs
def parse_args():
    """Parse the command line for the detector FLOPs counter.

    Returns:
        argparse.Namespace with:
            config: path of the config file to load.
            shape: list of ints (one value, or ``h w``); defaults to
                ``[1280, 800]``.
            cfg_options: key/value overrides collected by mmengine's
                ``DictAction``, or None when the flag is not given.
    """
    cfg_options_help = (
        "override some settings in the used config, the key-value pair "
        "in xxx=yyy format will be merged into config file. If the value to "
        'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
        'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
        "Note that the quotation marks are necessary and that no white space "
        "is allowed."
    )

    cli = argparse.ArgumentParser(description="Get a detector flops")
    cli.add_argument("config", help="train config file path")
    cli.add_argument(
        "--shape",
        type=int,
        nargs="+",
        default=[1280, 800],
        help="input image size",
    )
    cli.add_argument(
        "--cfg-options",
        nargs="+",
        action=DictAction,
        help=cfg_options_help,
    )
    return cli.parse_args()
def _torch_older_than(version, minimum=(1, 12)):
    """Return True when ``version`` is older than ``minimum`` (major, minor).

    The leading ``major.minor`` digits are compared as integers, so "1.9.0"
    is correctly treated as older than 1.12; a plain string comparison says
    it is newer. Version strings that do not start with ``<int>.<int>`` are
    treated as new enough.
    """
    import re  # local import: only this helper needs it

    match = re.match(r"(\d+)\.(\d+)", str(version))
    if match is None:
        return False
    return tuple(int(part) for part in match.groups()) < tuple(minimum)


def _build_model(cfg):
    """Build ``cfg.model``, move it to the GPU when available, and convert
    SyncBN layers back to plain BN via ``revert_sync_batchnorm``."""
    model = MODELS.build(cfg.model)
    if torch.cuda.is_available():
        model = model.cuda()
    return revert_sync_batchnorm(model)


def _measure(model, inputs):
    """Return ``(flops, params, tp_macs)`` for a forward pass on ``inputs``.

    ``flops`` and ``params`` come from mmengine's
    ``get_model_complexity_info``; ``tp_macs`` is the count returned by
    torchprofile's ``profile_macs``.
    """
    outputs = get_model_complexity_info(
        model, None, inputs=inputs, show_table=False, show_arch=False
    )
    return outputs["flops"], outputs["params"], profile_macs(model, inputs)


def inference(args, logger):
    """Build the detector described by ``args.config`` and count FLOPs/params.

    Two strategies are tried in order:

    1. Feed a randomly generated ``3 x h x w`` image through the model.
    2. If that raises ``TypeError`` (e.g. two-stage detectors whose forward
       needs ``data_samples`` to get ``rpn_results_list``), take one batch
       from ``cfg.val_dataloader`` and bind its ``data_samples`` into
       ``model.forward``.

    Args:
        args: Parsed CLI namespace with ``config`` (path), ``shape`` (one
            value -> square ``h == w``, or two values ``h, w``) and
            ``cfg_options`` (overrides merged into the config, or None).
        logger: Logger used for the torch-version warning, the
            missing-config error and the fallback notice.

    Returns:
        dict with keys ``ori_shape``, ``pad_shape``, ``compute_type``,
        ``flops``, ``torchprofile_flops`` and ``params``; the last three are
        formatted with mmengine's ``_format_size``.

    Raises:
        FileNotFoundError: If ``args.config`` does not exist.
        ValueError: If ``args.shape`` has neither one nor two values.
    """
    # Compare numerically: str(torch.__version__) < "1.12" is lexicographic
    # and misses e.g. torch 1.9.x.
    if _torch_older_than(torch.__version__):
        logger.warning(
            "Some config files, such as configs/yolact and configs/detectors,"
            "may have compatibility issues with torch.jit when torch<1.12. "
            "If you want to calculate flops for these models, "
            "please make sure your pytorch version is >=1.12."
        )

    config_name = Path(args.config)
    if not config_name.exists():
        logger.error(f"{config_name} not found.")
        # Fail fast instead of continuing into Config.fromfile with a bad path.
        raise FileNotFoundError(f"{config_name} not found.")

    cfg = Config.fromfile(args.config)
    # Only the path string is used; the TemporaryDirectory object is not kept,
    # so the directory itself may already be removed once it is collected.
    cfg.work_dir = tempfile.TemporaryDirectory().name
    cfg.log_level = "WARN"
    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)
    init_default_scope(cfg.get("default_scope", "mmdet"))

    # TODO: The following usage is temporary and not safe
    # use hard code to convert mmSyncBN to SyncBN. This is a known
    # bug in mmengine, mmSyncBN requires a distributed environment,
    # this question involves models like configs/strong_baselines
    if hasattr(cfg, "head_norm_cfg"):
        cfg["head_norm_cfg"] = dict(type="SyncBN", requires_grad=True)
        cfg["model"]["roi_head"]["bbox_head"]["norm_cfg"] = dict(
            type="SyncBN", requires_grad=True
        )
        cfg["model"]["roi_head"]["mask_head"]["norm_cfg"] = dict(
            type="SyncBN", requires_grad=True
        )

    if len(args.shape) == 1:
        h = w = args.shape[0]
    elif len(args.shape) == 2:
        h, w = args.shape
    else:
        raise ValueError("invalid input shape")

    result = {}
    try:
        # Strategy 1: a random image, no dataset needed.
        model = _build_model(cfg)
        data_batch = {"inputs": [torch.rand(3, h, w)], "batch_samples": [None]}
        data = model.data_preprocessor(data_batch)
        result["ori_shape"] = (h, w)
        result["pad_shape"] = data["inputs"].shape[-2:]
        model.eval()
        flops, params, tp_flops = _measure(model, data["inputs"])
        result["compute_type"] = "direct: randomly generate a picture"
    except TypeError:
        # Strategy 2: models whose forward needs data_samples get them from
        # one real validation batch, bound into forward via partial().
        logger.warning("Failed to directly get FLOPs, try to get flops with real data")
        data_loader = Runner.build_dataloader(cfg.val_dataloader)
        data_batch = next(iter(data_loader))
        model = _build_model(cfg)
        model.eval()
        _forward = model.forward
        data = model.data_preprocessor(data_batch)
        result["ori_shape"] = data["data_samples"][0].ori_shape
        result["pad_shape"] = data["data_samples"][0].pad_shape
        del data_loader
        model.forward = partial(_forward, data_samples=data["data_samples"])
        flops, params, tp_flops = _measure(model, data["inputs"])
        result["compute_type"] = "dataloader: load a picture from the dataset"

    result["flops"] = _format_size(flops)
    result["torchprofile_flops"] = _format_size(tp_flops)
    result["params"] = _format_size(params)
    return result
def main():
    """CLI entry point: measure the configured detector and print a summary.

    All values are read from the result of ``inference`` before anything is
    printed. A note about the size divisor is printed first when the padded
    input shape differs from the original one.
    """
    result = inference(parse_args(), MMLogger.get_instance(name="MMLogger"))

    wanted = (
        "ori_shape",
        "pad_shape",
        "flops",
        "torchprofile_flops",
        "params",
        "compute_type",
    )
    ori_shape, pad_shape, flops, tp_flops, params, compute_type = (
        result[key] for key in wanted
    )

    rule = "=" * 30
    if pad_shape != ori_shape:
        print(
            f"{rule}\nUse size divisor set input shape from {ori_shape} to {pad_shape}"
        )

    report_lines = [
        rule,
        f"Compute type: {compute_type}",
        f"Input shape: {pad_shape}",
        f"Flops: {flops}",
        f"Flops (torchprofile): {tp_flops}",
        f"Params: {params}",
        rule,
    ]
    print("\n".join(report_lines))

    print(
        "!!!Please be cautious if you use the results in papers. "
        "You may need to check if all ops are supported and verify "
        "that the flops computation is correct."
    )


if __name__ == "__main__":
    main()