-
Notifications
You must be signed in to change notification settings - Fork 2
/
compute_flops.py
42 lines (32 loc) · 1.46 KB
/
compute_flops.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
# -*- coding: utf-8 -*-
# ------------------------------------------------------------------------
# Copyright (c) 2021 megvii-model. All Rights Reserved.
# ------------------------------------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates.
import numpy as np
from collections import Counter
import tqdm
from fvcore.nn import flop_count_table # can also try flop_count_str
from detectron2.utils.analysis import FlopCountAnalysis
from main import get_args_parser as get_main_args_parser
from models import build_SAPDETR
from datasets import build_dataset
def do_flop(num_inputs=100):
    """Measure the FLOPs of a SAPDETR model on validation samples.

    The model and dataset are built from the command-line arguments of
    ``main.py``; ``get_main_args_parser().parse_args()`` reads ``sys.argv``.
    Up to ``num_inputs`` validation samples are each passed to
    ``FlopCountAnalysis`` on the GPU. Three summaries are then printed:

    * the FLOP table of the last analysed sample,
    * the average GFLOPs per operator type over all analysed samples,
    * the mean and standard deviation of the total GFLOPs per sample.

    Args:
        num_inputs: Maximum number of validation samples to analyse.
            Defaults to 100, the value that was previously hard-coded.

    Raises:
        ValueError: If ``num_inputs`` is smaller than 1.
        RuntimeError: If the validation dataset yields no samples.
    """
    if num_inputs < 1:
        raise ValueError(
            "num_inputs must be a positive integer, got {}".format(num_inputs)
        )

    main_args = get_main_args_parser().parse_args()
    dataset = build_dataset('val', main_args)
    model, _, _ = build_SAPDETR(main_args)
    model.cuda()
    model.eval()

    counts = Counter()
    total_flops = []
    flops = None
    # zip() stops at whichever runs out first: the progress bar or the dataset.
    for idx, data in zip(tqdm.trange(num_inputs), dataset):  # noqa
        # data[0] is presumably the image tensor of an (image, target) pair -- TODO confirm
        flops = FlopCountAnalysis(model, [data[0].cuda()])
        if idx > 0:
            # Only let the first sample report unsupported ops / uncalled modules,
            # so the same warnings are not repeated for every sample.
            flops.unsupported_ops_warnings(False).uncalled_modules_warnings(False)
        counts += flops.by_operator()
        total_flops.append(flops.total())

    if flops is None:
        # Without this guard an empty dataset surfaced as a NameError below.
        raise RuntimeError(
            "The validation dataset yielded no samples; nothing to analyse."
        )

    num_samples = len(total_flops)
    print("Flops table computed from only one input sample:\n" + flop_count_table(flops))
    print(
        "Average GFlops for each type of operators:\n"
        + str([(k, v / num_samples / 1e9) for k, v in counts.items()])
    )
    print(
        "Total GFlops: {:.1f}±{:.1f}".format(
            np.mean(total_flops) / 1e9, np.std(total_flops) / 1e9
        )
    )
if __name__ == "__main__":
    # Executed directly as a script (not imported): run the FLOP analysis
    # with its default settings.
    do_flop()