Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] NMS update #957

Merged
merged 31 commits into from
May 31, 2021
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
8b1f267
Add score_threshold and max_num to NMS
SemyonBevzuk Apr 15, 2021
925fae1
Fix codestyle
SemyonBevzuk Apr 15, 2021
3cee61c
Merge branch 'master' into nms_update
SemyonBevzuk Apr 16, 2021
fd8b082
Fix codestyle
SemyonBevzuk Apr 16, 2021
a6eae23
Fix inds in nms
SemyonBevzuk Apr 21, 2021
ae469a7
Update nms docstring
SemyonBevzuk Apr 21, 2021
23a465b
Move score_threshold and max_num arguments
SemyonBevzuk Apr 21, 2021
61c7286
Fix args order in docstring
SemyonBevzuk Apr 21, 2021
de02886
Merge branch 'master' into nms_update
RunningLeon Apr 22, 2021
8b4e7de
fix lint of c++ file
RunningLeon Apr 22, 2021
a9e45d6
Merge remote-tracking branch 'upstream/master' into nms_update
SemyonBevzuk Apr 22, 2021
5c083b8
Remove torch.onnx.is_in_onnx_export() and add max_num to batched_nms …
SemyonBevzuk Apr 23, 2021
79e6941
Rewrote max_num handling in NMSop.symbolic
SemyonBevzuk Apr 26, 2021
f8b1a7b
Added processing max_output_boxes_per_class when exporting to TensorRT
SemyonBevzuk Apr 27, 2021
9dba3c5
Added score_threshold and max_num for NMS in test_onnx.py and test_te…
SemyonBevzuk Apr 27, 2021
1355f5c
Remove _is_value(max_num)
SemyonBevzuk Apr 30, 2021
bc7fbb7
Merge remote-tracking branch 'upstream/master' into nms_update
SemyonBevzuk Apr 30, 2021
d2702b8
fix ci errors with torch==1.3.1
RunningLeon May 6, 2021
dc93be1
Update test_batched_nms in test_nms.py
SemyonBevzuk May 7, 2021
95f1bbb
Added tests for preprocess_onnx
SemyonBevzuk May 14, 2021
6d415c9
Fix
SemyonBevzuk May 14, 2021
9447eb6
Moved 'test_tensorrt_preprocess.py' and 'preprocess', updated 'remove…
SemyonBevzuk May 14, 2021
e735646
Update mmcv/tensorrt/__init__.py
SemyonBevzuk May 17, 2021
9d9c169
Fix segfault torch==1.3.1 (remove onnx.checker.check_model)
SemyonBevzuk May 18, 2021
8b03816
Returned 'onnx.checker.check_model' with torch version check
SemyonBevzuk May 19, 2021
5cd8a54
Changed torch version from 1.3.1 to 1.4.0
SemyonBevzuk May 19, 2021
5f74a4c
update version check
RunningLeon May 21, 2021
7ba082e
remove check for onnx
RunningLeon May 21, 2021
c9ac555
merge master and fix conflicts
RunningLeon May 24, 2021
f9fc45e
Merge branch 'nms_update' of github.com:SemyonBevzuk/mmcv into nms_up…
RunningLeon May 24, 2021
85576ea
merge master and resolve conflicts
RunningLeon May 25, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 8 additions & 5 deletions mmcv/tensorrt/tensorrt_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,18 +32,18 @@ def preprocess_onnx(onnx_model):

init_dict = {_.name: _ for _ in initializers}

nodes_to_remove = []
nodes_name_to_remove = set()

def is_node_without_output(name):
for node_name, node in node_dict.items():
if node not in nodes_to_remove:
if node_name not in nodes_name_to_remove:
if name in node.input:
return False
return True

def mark_nodes_to_remove(name):
node = node_dict[name]
nodes_to_remove.append(node)
nodes_name_to_remove.add(name)
for input_node_name in node.input:
if is_node_without_output(input_node_name):
mark_nodes_to_remove(input_node_name)
Expand Down Expand Up @@ -85,13 +85,16 @@ def parse_data(name, typ, default_value=0):
if len(node_inputs) >= 3:
max_output_boxes_per_class = parse_data(
node_inputs[2], np.int64, max_output_boxes_per_class)
RunningLeon marked this conversation as resolved.
Show resolved Hide resolved
mark_nodes_to_remove(node_inputs[2])

if len(node_inputs) >= 4:
iou_threshold = parse_data(node_inputs[3], np.float32,
iou_threshold)
mark_nodes_to_remove(node_inputs[3])

if len(node_inputs) >= 5:
score_threshold = parse_data(node_inputs[4], np.float32)
mark_nodes_to_remove(node_inputs[4])

new_node = onnx.helper.make_node(
'NonMaxSuppression',
Expand All @@ -110,8 +113,8 @@ def parse_data(name, typ, default_value=0):
nodes.insert(idx, new_node)
nodes.remove(node)

for node in nodes_to_remove:
nodes.remove(node)
for node_name in nodes_name_to_remove:
nodes.remove(node_dict[node_name])

return onnx_model

Expand Down
74 changes: 74 additions & 0 deletions tests/test_tensorrt_preprocess.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
import os
RunningLeon marked this conversation as resolved.
Show resolved Hide resolved
from functools import wraps

import onnx
import torch

from mmcv.ops import nms
from mmcv.tensorrt.tensorrt_utils import preprocess_onnx
RunningLeon marked this conversation as resolved.
Show resolved Hide resolved


def remove_tmp_file(func):

@wraps(func)
def wrapper(*args, **kwargs):
onnx_file = 'tmp.onnx'
kwargs['onnx_file'] = onnx_file
result = func(*args, **kwargs)
RunningLeon marked this conversation as resolved.
Show resolved Hide resolved
if os.path.exists(onnx_file):
os.remove(onnx_file)
return result

return wrapper


@remove_tmp_file
def export_nms_module_to_onnx(module, onnx_file):
torch_model = module()
torch_model.eval()

input = (torch.rand([100, 4], dtype=torch.float32),
torch.rand([100], dtype=torch.float32))

torch.onnx.export(
torch_model,
input,
onnx_file,
opset_version=11,
input_names=['boxes', 'scores'],
output_names=['output'])

onnx_model = onnx.load(onnx_file)
onnx.checker.check_model(onnx_model)
return onnx_model


def test_can_handle_nms_with_constant_maxnum():

class NMS_with_const_maxnum(torch.nn.Module):
RunningLeon marked this conversation as resolved.
Show resolved Hide resolved

def forward(self, boxes, scores):
return nms(boxes, scores, iou_threshold=0.4, max_num=10)

onnx_model = export_nms_module_to_onnx(NMS_with_const_maxnum)
preprocess_onnx_model = preprocess_onnx(onnx_model)
for node in preprocess_onnx_model.graph.node:
if 'NonMaxSuppression' in node.name:
assert len(node.attribute) == 5, 'The NMS must have 5 attributes.'


def test_can_handle_nms_with_undefined_maxnum():

class NMS_with_const_maxnum(torch.nn.Module):
RunningLeon marked this conversation as resolved.
Show resolved Hide resolved

def forward(self, boxes, scores):
return nms(boxes, scores, iou_threshold=0.4)

onnx_model = export_nms_module_to_onnx(NMS_with_const_maxnum)
preprocess_onnx_model = preprocess_onnx(onnx_model)
for node in preprocess_onnx_model.graph.node:
if 'NonMaxSuppression' in node.name:
assert len(node.attribute) == 5, \
'The NMS must have 5 attributes.'
assert node.attribute[2].i > 0, \
'The max_output_boxes_per_class is not defined correctly.'