-
Notifications
You must be signed in to change notification settings - Fork 5
/
benchmark.py
103 lines (84 loc) · 3.12 KB
/
benchmark.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import random
import string
import sys
import threading
import time
from collections import namedtuple
from numpy import mean
from service.client import BertClient
from service.server import BertServer
# Port number shared by both sides of the benchmark: it is placed in the
# server settings under the 'port' key and passed to every BertClient.
PORT = 5557
def tprint(msg):
    """Write ``msg`` followed by a newline to stdout, then flush.

    Like ``print``, but the text and its trailing newline are emitted in a
    single ``write`` call so that lines produced by different threads are
    less likely to get their newlines mixed up.

    Args:
        msg: the value to output. Non-string values are converted with
            ``str()`` first, mirroring what ``print`` accepts.
    """
    sys.stdout.write(str(msg) + '\n')
    sys.stdout.flush()
class BenchmarkClient(threading.Thread):
    """Thread that times repeated ``encode`` requests sent through a BertClient.

    The request payload (a batch of random strings) is generated once in the
    constructor. After the thread has run, ``avg_time`` holds the mean
    duration in seconds of one ``encode`` call.
    """

    def __init__(self, args):
        """Prepare the request batch from the benchmark settings.

        Args:
            args: object exposing ``max_seq_len`` (characters per generated
                string), ``client_batch_size`` (strings per request) and
                ``num_repeat`` (number of timed requests).
        """
        super().__init__()
        alphabet = string.ascii_uppercase + string.digits
        self.batch = [''.join(random.choices(alphabet, k=args.max_seq_len))
                      for _ in range(args.client_batch_size)]
        self.num_repeat = args.num_repeat
        # Remains 0 until run() has completed.
        self.avg_time = 0

    def run(self):
        """Send the batch ``num_repeat`` times and record the mean duration."""
        time_all = []
        bc = BertClient(port=PORT, show_server_config=False)
        for _ in range(self.num_repeat):
            start_t = time.perf_counter()
            bc.encode(self.batch)
            time_all.append(time.perf_counter() - start_t)
        # Several clients may finish concurrently; go through tprint (one
        # write + flush) rather than print so their output lines stay intact.
        tprint(str(time_all))
        self.avg_time = mean(time_all)
if __name__ == '__main__':
    # Baseline settings; every experiment below overrides exactly one of them.
    common = {
        'model_dir': '/data/cips/data/lab/data/model/chinese_L-12_H-768_A-12',
        'num_worker': 2,
        'num_repeat': 5,
        'port': PORT,
        'max_seq_len': 40,
        'client_batch_size': 2048,
        'max_batch_size': 256,
        'num_client': 1
    }
    # Setting name -> values to sweep while the other settings stay at baseline.
    experiments = {
        'client_batch_size': [1, 4, 8, 16, 64, 256, 512, 1024, 2048, 4096],
        'max_batch_size': [32, 64, 128, 256, 512],
        'max_seq_len': [20, 40, 80, 160, 320],
        'num_client': [2, 4, 8, 16, 32],
    }

    # The context manager closes the report file even if an experiment raises.
    with open('benchmark.result', 'w') as fp:
        for var_name, var_lst in experiments.items():
            # set common args
            # NOTE: the namedtuple *class* is used as a plain attribute bag
            # here (it is never instantiated); it is rebuilt per experiment so
            # the previous experiment's override does not carry over.
            args = namedtuple('args', ','.join(common.keys()))
            for k, v in common.items():
                setattr(args, k, v)

            avg_speed = []
            for var in var_lst:
                # override exp args
                setattr(args, var_name, var)
                server = BertServer(args)
                server.start()
                try:
                    # Fixed wait to give the server time to come up; no
                    # readiness check is performed.
                    time.sleep(15)
                    all_clients = [BenchmarkClient(args)
                                   for _ in range(args.num_client)]

                    tprint('num_client: %d' % len(all_clients))
                    for bc in all_clients:
                        bc.start()

                    # Per-client throughput in sequences per second.
                    all_thread_speed = []
                    for bc in all_clients:
                        bc.join()
                        cur_speed = args.client_batch_size / bc.avg_time
                        all_thread_speed.append(cur_speed)

                    max_speed = int(max(all_thread_speed))
                    min_speed = int(min(all_thread_speed))
                    t_avg_speed = int(mean(all_thread_speed))

                    # The latency column shows the last client's mean time
                    # only, not an aggregate over all clients.
                    last_avg_time = all_clients[-1].avg_time
                    tprint('%s: %5d\t%.3f\t%d/s'
                           % (var_name, var, last_avg_time, t_avg_speed))
                    tprint('max speed: %d\t min speed: %d'
                           % (max_speed, min_speed))
                    avg_speed.append(t_avg_speed)
                finally:
                    # Always shut the server down so a failed run does not
                    # leave it alive for the next configuration.
                    server.close()

            # Emit one Markdown table per swept setting.
            fp.write('#### Speed wrt. `%s`\n\n' % var_name)
            fp.write('|`%s`|seqs/s|\n' % var_name)
            fp.write('|---|---|\n')
            for i, j in zip(var_lst, avg_speed):
                fp.write('|%d|%d|\n' % (i, j))
            # Flush after each experiment so partial results survive a crash.
            fp.flush()