-
Notifications
You must be signed in to change notification settings - Fork 1.5k
/
demo_trt.py
201 lines (157 loc) · 7.02 KB
/
demo_trt.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
import sys
import os
import time
import argparse
import numpy as np
import cv2
# from PIL import Image
import tensorrt as trt
import pycuda.driver as cuda
import pycuda.autoinit
from tool.utils import *
# Python 2 has no built-in FileNotFoundError. Probe for the name and, when the
# lookup fails, alias it to IOError so the `raise FileNotFoundError(...)`
# statements in this file work on either interpreter.
try:
    FileNotFoundError
except NameError:
    FileNotFoundError = IOError
def GiB(val):
    '''
    Convert a size expressed in GiB to bytes.

    Args:
        val (int or float): Number of gibibytes.

    Returns:
        int or float: ``val`` multiplied by 2**30.
    '''
    # Parenthesised on purpose: `val * 1 << 30` parses as `(val * 1) << 30`,
    # which gives the same result for ints but raises TypeError for floats.
    return val * (1 << 30)
def find_sample_data(description="Runs a TensorRT Python sample", subfolder="", find_files=None):
    '''
    Parses sample arguments and locates the sample's data directory and files.

    Args:
        description (str): Description of the sample, passed to the argument parser.
        subfolder (str): The subfolder containing data relevant to this sample.
        find_files (list of str): Filenames to find. A list passed by the caller is
            updated in place: each filename is replaced with an absolute path.
            Defaults to no files.

    Returns:
        tuple: ``(data_path, find_files)`` - the data directory that was used and
        the list of absolute file paths.

    Raises:
        FileNotFoundError: If the data directory or a requested file does not exist.
    '''
    # A fresh list per call instead of a shared mutable default argument.
    if find_files is None:
        find_files = []

    # Standard command-line arguments for all samples.
    kDEFAULT_DATA_ROOT = os.path.join(os.sep, "usr", "src", "tensorrt", "data")
    parser = argparse.ArgumentParser(description=description, formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("-d", "--datadir", help="Location of the TensorRT sample data directory.", default=kDEFAULT_DATA_ROOT)
    # parse_known_args tolerates options that belong to the calling script.
    args, _ = parser.parse_known_args()

    # If data directory is not specified, argparse has already supplied the default.
    data_root = args.datadir

    # If the subfolder exists, append it to the path, otherwise use the provided path as-is.
    subfolder_path = os.path.join(data_root, subfolder)
    data_path = subfolder_path
    if not os.path.exists(subfolder_path):
        print("WARNING: " + subfolder_path + " does not exist. Trying " + data_root + " instead.")
        data_path = data_root

    # Make sure data directory exists.
    if not os.path.exists(data_path):
        raise FileNotFoundError(data_path + " does not exist. Please provide the correct data path with the -d option.")

    # Find all requested files, rewriting each entry to its absolute path.
    for index, f in enumerate(find_files):
        find_files[index] = os.path.abspath(os.path.join(data_path, f))
        if not os.path.exists(find_files[index]):
            raise FileNotFoundError(find_files[index] + " does not exist. Please provide the correct data path with the -d option.")

    return data_path, find_files
# Small named container for a host/device buffer pair; nicer than a bare 2-tuple.
class HostDeviceMem(object):
    '''
    Holds a host-side buffer together with its device-side counterpart.

    Attributes are plain and public: ``host`` and ``device`` are stored exactly as
    given and may be read or reassigned by callers.
    '''

    def __init__(self, host_mem, device_mem):
        self.device = device_mem
        self.host = host_mem

    def __str__(self):
        pieces = ["Host:\n", str(self.host), "\nDevice:\n", str(self.device)]
        return "".join(pieces)

    def __repr__(self):
        # The repr mirrors the human-readable form.
        return self.__str__()
# Allocates all buffers required for an engine, i.e. host/device inputs/outputs.
def allocate_buffers(engine, batch_size):
    '''
    Create one host/device buffer pair for every binding of ``engine``.

    Args:
        engine: TensorRT engine. It is iterated to obtain its bindings and queried
            for each binding's shape, dtype and input/output role.
        batch_size (int): Multiplier applied to the volume of each binding shape.

    Returns:
        tuple: ``(inputs, outputs, bindings, stream)``. ``inputs`` and ``outputs``
        are lists of HostDeviceMem, ``bindings`` holds every device buffer converted
        with ``int()`` in iteration order, and ``stream`` is a new ``cuda.Stream``.
    '''
    inputs, outputs, bindings = [], [], []
    stream = cuda.Stream()

    for binding in engine:
        shape = engine.get_binding_shape(binding)
        n_elems = trt.volume(shape) * batch_size
        # In case the batch dimension is -1 (dynamic): the computed volume is then
        # presumably negative, so flip its sign.
        # NOTE(review): only the leading dimension is inspected; shapes with other
        # dynamic dimensions may still yield a wrong or negative size - confirm.
        if shape[0] < 0:
            n_elems = -n_elems
        np_dtype = trt.nptype(engine.get_binding_dtype(binding))

        # Host buffer first; the device buffer is sized from its byte count.
        host_buf = cuda.pagelocked_empty(n_elems, np_dtype)
        device_buf = cuda.mem_alloc(host_buf.nbytes)

        # Every binding contributes its device buffer to the bindings list.
        bindings.append(int(device_buf))

        # File the pair under inputs or outputs according to the binding's role.
        destination = inputs if engine.binding_is_input(binding) else outputs
        destination.append(HostDeviceMem(host_buf, device_buf))

    return inputs, outputs, bindings, stream
# Generalized for any number of inputs and outputs.
# inputs and outputs are expected to be lists of HostDeviceMem objects.
def do_inference(context, bindings, inputs, outputs, stream):
    '''
    Copy inputs to the device, execute the engine, and copy the outputs back.

    Args:
        context: Execution context whose ``execute_async`` is invoked.
        bindings: Passed through to ``execute_async`` as ``bindings``.
        inputs: HostDeviceMem objects copied host -> device before execution.
        outputs: HostDeviceMem objects copied device -> host after execution.
        stream: CUDA stream on which the copies and the execution are enqueued.

    Returns:
        list: The ``host`` attribute of every entry in ``outputs``, in order.
    '''
    # Enqueue the host -> device copies for every input.
    for src in inputs:
        cuda.memcpy_htod_async(src.device, src.host, stream)

    # Enqueue the inference itself on the same stream.
    context.execute_async(bindings=bindings, stream_handle=stream.handle)

    # Enqueue the device -> host copies for every output.
    for dst in outputs:
        cuda.memcpy_dtoh_async(dst.host, dst.device, stream)

    # Wait for the stream before the host buffers are handed back.
    stream.synchronize()

    # Only the host-side buffers are returned.
    return [dst.host for dst in outputs]
# Module-level TensorRT logger, created once at import time and handed to
# trt.Runtime in get_engine().
TRT_LOGGER = trt.Logger()
def main(engine_path, image_path, image_size, num_classes=80):
    '''
    Run the TensorRT engine on one image and save the annotated result.

    The detections are drawn onto the image and written to 'predictions_trt.jpg'.

    Args:
        engine_path (str): Path of the serialized TensorRT engine file.
        image_path (str): Path of the image to run detection on.
        image_size (tuple): ``(height, width)`` used for binding 0's shape and for
            resizing the image.
        num_classes (int): Number of classes the engine predicts. It also selects
            the class-names file: 20 -> 'data/voc.names', 80 -> 'data/coco.names',
            anything else -> 'data/names'. Defaults to 80.

    Raises:
        IOError: If the image cannot be read.
    '''
    with get_engine(engine_path) as engine, engine.create_execution_context() as context:
        buffers = allocate_buffers(engine, 1)
        IN_IMAGE_H, IN_IMAGE_W = image_size
        context.set_binding_shape(0, (1, 3, IN_IMAGE_H, IN_IMAGE_W))

        image_src = cv2.imread(image_path)
        # cv2.imread signals failure by returning None instead of raising, so fail
        # here with a clear message rather than later inside cv2.resize.
        if image_src is None:
            raise IOError("Could not read image file: {}".format(image_path))

        # Run twice as a speed check: the first iteration is usually longer.
        for _ in range(2):
            boxes = detect(context, buffers, image_src, image_size, num_classes)

        if num_classes == 20:
            namesfile = 'data/voc.names'
        elif num_classes == 80:
            namesfile = 'data/coco.names'
        else:
            namesfile = 'data/names'

        class_names = load_class_names(namesfile)
        # boxes[0]: detections for the single image in the batch.
        plot_boxes_cv2(image_src, boxes[0], savename='predictions_trt.jpg', class_names=class_names)
def get_engine(engine_path):
    '''
    Read a serialized engine from disk and deserialize it.

    Args:
        engine_path (str): Path of the serialized engine file.

    Returns:
        The result of ``runtime.deserialize_cuda_engine`` for the file's bytes.
    '''
    print("Reading engine from file {}".format(engine_path))
    with open(engine_path, "rb") as engine_file:
        with trt.Runtime(TRT_LOGGER) as runtime:
            serialized_engine = engine_file.read()
            return runtime.deserialize_cuda_engine(serialized_engine)
def detect(context, buffers, image_src, image_size, num_classes):
    '''
    Preprocess one image, run it through the engine, and post-process the result.

    Args:
        context: Execution context handed to do_inference.
        buffers (tuple): ``(inputs, outputs, bindings, stream)`` as returned by
            allocate_buffers.
        image_src: Source image; it is converted with cv2.COLOR_BGR2RGB, so BGR
            channel order is expected.
        image_size (tuple): ``(height, width)`` the image is resized to.
        num_classes (int): Last dimension used when reshaping the second output.

    Returns:
        The value returned by ``post_processing`` for this image.
    '''
    net_h, net_w = image_size

    t_start = time.time()

    # Input: resize, BGR -> RGB, HWC -> CHW float32, add a batch axis, scale to [0, 1].
    scaled = cv2.resize(image_src, (net_w, net_h), interpolation=cv2.INTER_LINEAR)
    rgb = cv2.cvtColor(scaled, cv2.COLOR_BGR2RGB)
    chw = np.transpose(rgb, (2, 0, 1)).astype(np.float32)
    img_in = np.expand_dims(chw, axis=0)
    img_in /= 255.0
    img_in = np.ascontiguousarray(img_in)
    print("Shape of the network input: ", img_in.shape)

    inputs, outputs, bindings, stream = buffers
    print('Length of inputs: ', len(inputs))
    # The first input's host buffer is replaced by the preprocessed array.
    inputs[0].host = img_in

    trt_outputs = do_inference(context, bindings=bindings, inputs=inputs, outputs=outputs, stream=stream)
    print('Len of outputs: ', len(trt_outputs))

    # Restore the flat output buffers to batched shapes.
    # NOTE(review): presumably output 0 holds box coordinates and output 1 the
    # per-class confidences - confirm against the exported model.
    trt_outputs[0] = trt_outputs[0].reshape(1, -1, 1, 4)
    trt_outputs[1] = trt_outputs[1].reshape(1, -1, num_classes)

    t_end = time.time()

    # The reported time covers preprocessing, inference and the reshapes above.
    print('-----------------------------------')
    print(' TRT inference time: %f' % (t_end - t_start))
    print('-----------------------------------')

    # NOTE(review): 0.4 and 0.6 look like the confidence and NMS thresholds of
    # post_processing - confirm against tool.utils.
    return post_processing(img_in, 0.4, 0.6, trt_outputs)
if __name__ == '__main__':
    # Command line: <engine_path> <image_path> [<size> | <height> <width>]
    #   no size arguments  -> 416 x 416
    #   one size argument  -> square input of that size
    #   two size arguments -> height, then width
    if len(sys.argv) < 3:
        # Without this check a missing argument surfaces as a bare IndexError.
        sys.exit("Usage: python {} <engine_path> <image_path> [<size> | <height> <width>]".format(sys.argv[0]))

    engine_path = sys.argv[1]
    image_path = sys.argv[2]

    if len(sys.argv) < 4:
        image_size = (416, 416)
    elif len(sys.argv) < 5:
        image_size = (int(sys.argv[3]), int(sys.argv[3]))
    else:
        image_size = (int(sys.argv[3]), int(sys.argv[4]))

    main(engine_path, image_path, image_size)