-
Notifications
You must be signed in to change notification settings - Fork 1.1k
/
Copy pathbenchmark.py
354 lines (321 loc) · 13.6 KB
/
benchmark.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import multiprocessing as mp
from time import time
import torch
def parse_arguments(argv=None):
    """Build the benchmark command-line parser and parse the arguments.

    Multi-valued options (``--batch_size``, ``--input_len``,
    ``--input_output_len``, ``--gpu_weights_percent``) are returned as raw
    ``;``-separated strings; they are split and converted by the caller.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default), argparse reads ``sys.argv[1:]``.

    Returns:
        argparse.Namespace holding one attribute per option defined below.
    """
    parser = argparse.ArgumentParser(
        description='Benchmark TensorRT-LLM models.')
    parser.add_argument('-m',
                        '--model',
                        type=str,
                        default="dec",
                        choices=["dec", "enc", "enc-dec"],
                        help='Specify type of the model you want to benchmark. '
                        'Choose model between dec/enc/enc-dec.')
    parser.add_argument('--batch_size',
                        type=str,
                        default="8",
                        help=('Specify batch size(s) you want to benchmark. '
                              'Multiple batch sizes can be separated by ";", '
                              'example: "1;8;64".'))
    parser.add_argument(
        '--input_len',
        type=str,
        default="128",
        help=('Specify input length(s) you want to benchmark, '
              'this option is mainly for BERT. '
              'Multiple input lengths can be separated by ";", '
              'example: "20;60;128".'))
    parser.add_argument(
        '--input_output_len',
        type=str,
        default="128,20",
        help=('Specify input-output length(s) you want to benchmark, '
              'this option is mainly for GPT and GPT-like models. '
              'Multiple input-output length pairs can be separated by ";", '
              'example: "60,20;128,20".'))
    parser.add_argument(
        '--dtype',
        type=str,
        default='float16',
        choices=['float16', 'bfloat16', 'float32'],
        help='Choose data type between float16/bfloat16/float32.')
    # Numeric defaults are given as numbers; argparse would convert string
    # defaults through `type` anyway, so the parsed values are unchanged.
    parser.add_argument('--num_beams',
                        type=int,
                        default=1,
                        help='Specify number of beams you want to benchmark.')
    parser.add_argument('--top_k',
                        type=int,
                        default=1,
                        help='Specify Top-K value of decoding.')
    parser.add_argument('--top_p',
                        type=float,
                        default=0.0,
                        help='Specify Top-P value of decoding.')
    parser.add_argument(
        '--input_timing_cache',
        type=str,
        default=None,
        help=
        'The path to read timing cache, will be ignored if the file does not exist'
    )
    parser.add_argument('--output_timing_cache',
                        type=str,
                        default='model.cache',
                        help='The path to write timing cache')
    parser.add_argument(
        '--log_level',
        type=str,
        default="error",
        choices=['verbose', 'info', 'warning', 'error', 'internal_error'],
        help=
        'Choose log level between verbose/info/warning/error/internal_error.')
    parser.add_argument(
        '--warm_up',
        type=int,
        default=2,
        help='Specify warm up iterations before benchmark starts.')
    parser.add_argument(
        '--num_runs',
        type=int,
        default=10,
        help='Minimal number of iterations to run during benchmarking.')
    parser.add_argument(
        '--duration',
        type=int,
        default=60,
        help='Minimal duration of iterations to measure in seconds.')
    parser.add_argument(
        '--engine_dir',
        type=str,
        required=True,
        help='Directory containing the prebuilt engines to benchmark; '
        'the engines contained in the engine_dir will be used.')
    parser.add_argument(
        '--gpu_weights_percent',
        type=str,
        default="1.0",
        help='Specify the percentage of weights that reside on GPU (from 0 to 1). '
        'Multiple percentages can be separated by ";", '
        'example: "0;0.5;1".')
    parser.add_argument('--csv',
                        default=False,
                        action="store_true",
                        help='Output in CSV format.')
    parser.add_argument('--enable_cuda_graph',
                        default=False,
                        action='store_true',
                        help='Execute GPT session with CUDA graph.')
    parser.add_argument(
        '--quantization',
        type=str,
        default=None,
        choices=[
            'fp8', 'fp8_gemm', 'fp8_kv_cache', 'int8_sq_per_tensor',
            'int8_sq_per_token_channel', 'int8_weight_only', 'int4_weight_only',
            'int4_weight_only_awq', 'int4_weight_only_gptq',
            'int8_sq_per_channel_ootb'
        ],
        help="Optimize the model with specified quantization recipe")
    parser.add_argument(
        '--dump_profile',
        default=False,
        action='store_true',
        help="Print profile information per layer (default = disabled)")
    parser.add_argument(
        '--dump_layer_info',
        default=False,
        action='store_true',
        help=
        "Print layer information of the engine to console (default = disabled)")
    return parser.parse_args(argv)
def _run_timed_iterations(benchmarker, inputs, config, start, end,
                          benchmark_profiler, num_runs, min_duration):
    """Repeatedly call ``benchmarker.run`` and time each call with CUDA events.

    Iterates until at least ``num_runs`` calls have completed AND at least
    ``min_duration`` seconds of wall-clock time have elapsed. When
    ``benchmark_profiler`` is not None it is cleaned and started first, and
    it is passed to every run.

    Returns:
        (latencies, duration): the per-call ``start.elapsed_time(end)``
        values (milliseconds per the torch.cuda.Event API) and the elapsed
        wall-clock time in seconds, rounded to 3 decimals.
    """
    if benchmark_profiler is not None:
        benchmark_profiler.clean()
        benchmark_profiler.start()
    latencies = []
    cur_duration = 0
    start_time = time()
    while len(latencies) < num_runs or cur_duration < min_duration:
        start.record()
        benchmarker.run(inputs, config, benchmark_profiler=benchmark_profiler)
        end.record()
        # Wait for the recorded events to complete before reading the timing.
        torch.cuda.synchronize()
        latencies.append(start.elapsed_time(end))
        cur_duration = round(time() - start_time, 3)
    return latencies, cur_duration


def _format_latencies(latencies):
    """Render latencies for logging; lists over 20 entries keep only the
    first and last 10 values, with ``...`` marking the elided middle."""
    if len(latencies) <= 20:
        return str(latencies)
    head = ", ".join(str(v) for v in latencies[:10])
    tail = ", ".join(str(v) for v in latencies[-10:])
    return f"[{head}, ..., {tail}]"


def _latency_stats(latencies):
    """Return (mean, p95, p99) of a non-empty latency list, each rounded to
    3 decimals. Percentiles index the sorted samples at ``int(n * q)``; the
    input list is not modified."""
    count = len(latencies)
    ordered = sorted(latencies)
    mean = round(sum(latencies) / count, 3)
    p95 = round(ordered[int(count * 0.95)], 3)
    p99 = round(ordered[int(count * 0.99)], 3)
    return mean, p95, p99


def main(args):
    """Benchmark prebuilt TensorRT-LLM engines and report latency statistics.

    For every config yielded by the selected benchmarker this prepares the
    inputs (skipping configs that run out of GPU memory), warms up, runs the
    timed iterations, and reports mean / p95 / p99 latency together with the
    peak GPU memory seen by the memory monitor (0.0 when the monitor is
    disabled, i.e. for multi-rank runs). With ``--dump_profile`` on a
    ``dec`` model, a second pass records per-layer profile information.

    Args:
        args: Namespace returned by ``parse_arguments``.

    Raises:
        Exception: if a ``--gpu_weights_percent`` value is outside [0, 1]
            (or is NaN), or ``--model`` is unexpected.
        ValueError: if both ``--num_runs`` and ``--duration`` are <= 0,
            which would yield no timed iterations.
    """
    # We import tensorrt_llm here because MPI is initialized when
    # tensorrt_llm is imported, but mpi4py does not work well with
    # the start method `spawn` of Python multiprocessing,
    # so we set the start method first, then initialize MPI.
    from benchmark_profiler import BenchmarkProfiler
    from bert_benchmark import BERTBenchmark
    from enc_dec_benchmark import EncDecBenchmark
    from gpt_benchmark import GPTBenchmark

    import tensorrt_llm
    from tensorrt_llm.logger import logger

    logger.set_level(args.log_level)

    # Batch size(s), ";"-separated.
    batch_size_options = [int(i) for i in args.batch_size.split(';')]
    # Input length(s) (for BERT-like models).
    input_len_options = [int(i) for i in args.input_len.split(';')]
    # Input-output length pairs (for GPT-like and enc-dec models): "in,out;...".
    in_out_len_options = [[int(i) for i in pair.split(',')]
                          for pair in args.input_output_len.split(';')]
    # GPU weights percentage ratios.
    gpu_weights_percents = [
        float(r) for r in args.gpu_weights_percent.split(";")
    ]
    for percent in gpu_weights_percents:
        # Chained comparison also rejects NaN.
        if not 0.0 <= percent <= 1.0:
            raise Exception(
                "--gpu_weights_percent only accepts values between 0.0 and 1.0."
            )
    if args.num_runs <= 0 and args.duration <= 0:
        raise ValueError(
            "At least one of --num_runs or --duration must be positive.")

    rank = tensorrt_llm.mpi_rank()
    world_size = tensorrt_llm.mpi_world_size()

    # TODO: Re-enable memory monitor for multi-gpu benchmarks.
    # Current Mem Monitor will cause benchmark script hang
    # because MPI does not work well with multiprocessing.
    disable_mem_monitor = world_size > 1
    if not disable_mem_monitor:
        from mem_monitor import MemoryMonitor

    benchmark_profiler = None
    if args.model == "dec":
        benchmark_profiler = BenchmarkProfiler()
        benchmarker = GPTBenchmark(args, batch_size_options, in_out_len_options,
                                   gpu_weights_percents, rank, world_size)
    elif args.model == "enc":
        benchmarker = BERTBenchmark(args, batch_size_options, input_len_options,
                                    gpu_weights_percents, rank, world_size)
    elif args.model == "enc-dec":
        benchmarker = EncDecBenchmark(args, batch_size_options,
                                      in_out_len_options, gpu_weights_percents,
                                      rank, world_size)
    else:
        raise Exception(f'Unexpected model: {args.model}')

    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)

    # Host memory monitoring is disabled with CUDA graphs for CUDA graph
    # performance. This does not depend on the config, so decide it once.
    disable_host_mem_monitor = args.enable_cuda_graph
    if disable_host_mem_monitor:
        logger.warning('Disable host memory monitor when cuda graph is enabled.')

    benchmarker.print_report_header(args.csv,
                                    benchmark_profiler=benchmark_profiler)
    for config in benchmarker.get_config():
        try:
            # We pass in config instead of the gpu_weights_percent here to keep
            # this benchmark script agnostic to the length and contents of the
            # config.
            benchmarker.set_weight_streaming(config)
            inputs = benchmarker.prepare_inputs(config)
        except torch.cuda.OutOfMemoryError as e:
            logger.error(
                f'Exception {e} caught while allocating memory; skipping {config}'
            )
            continue

        torch.cuda.empty_cache()

        memory_monitor = None
        if not disable_mem_monitor:
            memory_monitor = MemoryMonitor(
                disable_host_mem_monitor=disable_host_mem_monitor)
            memory_monitor.start()

        try:
            for _ in range(args.warm_up):
                benchmarker.run(inputs, config)
            logger.info('Warm up done. Start benchmarking.')
            latencies, cur_duration = _run_timed_iterations(
                benchmarker, inputs, config, start, end, benchmark_profiler,
                args.num_runs, args.duration)
            logger.info(
                f'Benchmarking done. Iteration: {len(latencies)}, duration: {cur_duration} sec.'
            )
        except Exception as e:
            # Log without e.with_traceback(): calling it with no argument
            # raises TypeError and would mask the real error. The bare
            # `raise` keeps the original traceback.
            logger.error(f"Found exception during benchmarking: {e}")
            if memory_monitor is not None:
                memory_monitor.kill()
            raise

        if memory_monitor is not None:
            memory_monitor.stop()
            _, peak_gpu_used = memory_monitor.get_peak_memory_usage("GiB")
            peak_gpu_used = round(peak_gpu_used, 3)
        else:
            peak_gpu_used = 0.0

        if benchmark_profiler is not None:
            benchmark_profiler.add_aux_info('iter_count', len(latencies))
            benchmark_profiler.stop()

        # Print latencies to make it easier to check perf stability.
        logger.info(f"Latencies: {_format_latencies(latencies)}")

        latency, percentile95, percentile99 = _latency_stats(latencies)
        benchmarker.report(config,
                           latency,
                           percentile95,
                           percentile99,
                           peak_gpu_used,
                           csv=args.csv,
                           benchmark_profiler=benchmark_profiler)

        # Rerun for dumping profile per layer; its timings are not reported.
        if args.dump_profile and benchmark_profiler is not None:
            benchmark_profiler.set_recording_perf_profile(True)
            logger.info('Dump profile information per layer')
            try:
                for _ in range(args.warm_up):
                    benchmarker.run(inputs, config)
                _run_timed_iterations(benchmarker, inputs, config, start, end,
                                      benchmark_profiler, args.num_runs,
                                      args.duration)
                benchmarker.report_profiler(
                    benchmark_profiler=benchmark_profiler)
            except Exception as e:
                logger.error(f"Found exception during benchmarking: {e}")
                if memory_monitor is not None:
                    memory_monitor.kill()
                raise
if __name__ == '__main__':
    # The multiprocessing start method must be chosen before main() imports
    # tensorrt_llm, which initializes MPI; mpi4py does not cooperate with
    # `spawn` if it is selected afterwards.
    mp.set_start_method('spawn')
    main(parse_arguments())