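"""Benchmark every pretrained model exposed by glasses on the ImageNet validation set.

Results are appended to ./benchmark.csv and a table sorted by top-1 accuracy is
written to ./benchmark.md.
"""
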
import math
import time
from pathlib import Path
from pprint import pprint

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from sotabencheval.image_classification import ImageNetEvaluator
from torchvision.datasets import ImageNet
from tqdm.autonotebook import tqdm

from glasses.models import *
from glasses.models.AutoModel import AutoModel
from glasses.models.AutoTransform import AutoTransform

models = AutoModel.pretrained_models
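# Per-model batch size overrides; any model not listed here uses the default of 64.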
batch_sizes = {"efficientnet_b0": 256, "efficientnet_b1": 128, "efficientnet_b5": 8}
# code stolen from https://github.com/ansleliu/EfficientNet.PyTorch/blob/master/eval.py
# if you are using it, show some love and star his repo!
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def get_img_id(image_name):
    # "/path/to/ILSVRC2012_val_00000001.JPEG" -> "ILSVRC2012_val_00000001"
    return image_name.split("/")[-1].replace(".JPEG", "")


def benchmark(
    model: nn.Module, transform, batch_size=64, device=device, fast: bool = False
):
    valid_dataset = ImageNet(
        root="/home/zuppif/Downloads/ImageNet", split="val", transform=transform
    )
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=12,
        pin_memory=True,
    )
    evaluator = ImageNetEvaluator(model_name="test", paper_arxiv_id="1905.11946")
    model.eval().to(device)
    num_batches = int(
        math.ceil(len(valid_loader.dataset) / float(valid_loader.batch_size))
    )
    start = time.time()
    with torch.no_grad():
        pbar = tqdm(np.arange(num_batches), leave=False)
        for i_val, (images, labels) in enumerate(valid_loader):
            images = images.to(device)
            labels = torch.squeeze(labels.to(device))
            net_out = model(images)
            # the evaluator expects a mapping from image id to the raw model outputs
            image_ids = [
                get_img_id(img[0])
                for img in valid_loader.dataset.imgs[
                    i_val
                    * valid_loader.batch_size : (i_val + 1)
                    * valid_loader.batch_size
                ]
            ]
            evaluator.add(dict(zip(image_ids, list(net_out.cpu().numpy()))))
            pbar.set_description(f"top1={evaluator.top1.avg:.2f}")
            pbar.update(1)
            if fast:
                # evaluate a single batch only, useful as a quick sanity check
                break
        pbar.close()
    stop = time.time()
    if fast:
        return evaluator.top1.avg, None, None
    else:
        res = evaluator.get_results()
        return res["Top 1 Accuracy"], res["Top 5 Accuracy"], stop - start

def benchmark_all() -> pd.DataFrame:
    save_path = Path("./benchmark.csv")
    # resume from a previous run: models already listed in benchmark.csv are skipped
    df = pd.DataFrame()
    if save_path.exists():
        df = pd.read_csv(str(save_path), index_col=0)
    index = []
    records = []
    bar = tqdm(models)
    try:
        for key in bar:
            if key not in df.index:
                try:
                    model = AutoModel.from_pretrained(key)
                    tr = AutoTransform.from_name(key)
                    batch_size = batch_sizes.get(key, 64)
                    bar.set_description(
                        f"{key}, size={tr.transforms[0].size}, batch_size={batch_size}"
                    )
                    top1, top5, elapsed = benchmark(model.to(device), tr, batch_size)
                    index.append(key)
                    data = {
                        "top1": top1,
                        "top5": top5,
                        "time": elapsed,
                        "batch_size": batch_size,
                    }
                    pprint(data)
                    records.append(data)
                except KeyError:
                    # no pretrained weights / transform available for this key
                    continue
    except Exception as e:
        # stop the sweep but keep whatever was collected so far
        print(e)
    if len(records) > 0:
        new_df = pd.DataFrame.from_records(records, index=index)
        df = new_df if df.empty else pd.concat([df, new_df])
        df.to_csv(str(save_path))
        mk = df.sort_values("top1", ascending=False).to_markdown()
        with open("./benchmark.md", "w") as f:
            f.write(mk)
    return df


if __name__ == "__main__":
    benchmark_all()