Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] NMS update #957

Merged
merged 31 commits into from
May 31, 2021
Merged
Show file tree
Hide file tree
Changes from 28 commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
8b1f267
Add score_threshold and max_num to NMS
SemyonBevzuk Apr 15, 2021
925fae1
Fix codestyle
SemyonBevzuk Apr 15, 2021
3cee61c
Merge branch 'master' into nms_update
SemyonBevzuk Apr 16, 2021
fd8b082
Fix codestyle
SemyonBevzuk Apr 16, 2021
a6eae23
Fix inds in nms
SemyonBevzuk Apr 21, 2021
ae469a7
Update nms docstring
SemyonBevzuk Apr 21, 2021
23a465b
Move score_threshold and max_num arguments
SemyonBevzuk Apr 21, 2021
61c7286
Fix args order in docstring
SemyonBevzuk Apr 21, 2021
de02886
Merge branch 'master' into nms_update
RunningLeon Apr 22, 2021
8b4e7de
fix lint of c++ file
RunningLeon Apr 22, 2021
a9e45d6
Merge remote-tracking branch 'upstream/master' into nms_update
SemyonBevzuk Apr 22, 2021
5c083b8
Remove torch.onnx.is_in_onnx_export() and add max_num to batched_nms …
SemyonBevzuk Apr 23, 2021
79e6941
Rewrote max_num handling in NMSop.symbolic
SemyonBevzuk Apr 26, 2021
f8b1a7b
Added processing max_output_boxes_per_class when exporting to TensorRT
SemyonBevzuk Apr 27, 2021
9dba3c5
Added score_threshold and max_num for NMS in test_onnx.py and test_te…
SemyonBevzuk Apr 27, 2021
1355f5c
Remove _is_value(max_num)
SemyonBevzuk Apr 30, 2021
bc7fbb7
Merge remote-tracking branch 'upstream/master' into nms_update
SemyonBevzuk Apr 30, 2021
d2702b8
fix ci errors with torch==1.3.1
RunningLeon May 6, 2021
dc93be1
Update test_batched_nms in test_nms.py
SemyonBevzuk May 7, 2021
95f1bbb
Added tests for preprocess_onnx
SemyonBevzuk May 14, 2021
6d415c9
Fix
SemyonBevzuk May 14, 2021
9447eb6
Moved 'test_tensorrt_preprocess.py' and 'preprocess', updated 'remove…
SemyonBevzuk May 14, 2021
e735646
Update mmcv/tensorrt/__init__.py
SemyonBevzuk May 17, 2021
9d9c169
Fix segfault torch==1.3.1 (remove onnx.checker.check_model)
SemyonBevzuk May 18, 2021
8b03816
Returned 'onnx.checker.check_model' with torch version check
SemyonBevzuk May 19, 2021
5cd8a54
Changed torch version from 1.3.1 to 1.4.0
SemyonBevzuk May 19, 2021
5f74a4c
update version check
RunningLeon May 21, 2021
7ba082e
remove check for onnx
RunningLeon May 21, 2021
c9ac555
merge master and fix conflicts
RunningLeon May 24, 2021
f9fc45e
Merge branch 'nms_update' of github.com:SemyonBevzuk/mmcv into nms_up…
RunningLeon May 24, 2021
85576ea
merge master and resolve conflicts
RunningLeon May 25, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 44 additions & 9 deletions mmcv/ops/nms.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import os
import sys

import numpy as np
import torch
Expand All @@ -15,13 +14,27 @@
class NMSop(torch.autograd.Function):

@staticmethod
def forward(ctx, bboxes, scores, iou_threshold, offset):
def forward(ctx, bboxes, scores, iou_threshold, offset, score_threshold,
max_num):
is_filtering_by_score = score_threshold > 0
if is_filtering_by_score:
valid_mask = scores > score_threshold
bboxes, scores = bboxes[valid_mask], scores[valid_mask]
valid_inds = torch.nonzero(
RunningLeon marked this conversation as resolved.
Show resolved Hide resolved
valid_mask, as_tuple=False).squeeze(dim=1)

inds = ext_module.nms(
bboxes, scores, iou_threshold=float(iou_threshold), offset=offset)

if max_num > 0:
inds = inds[:max_num]
if is_filtering_by_score:
inds = valid_inds[inds]
return inds

@staticmethod
def symbolic(g, bboxes, scores, iou_threshold, offset):
def symbolic(g, bboxes, scores, iou_threshold, offset, score_threshold,
max_num):
from ..onnx import is_custom_op_loaded
has_custom_op = is_custom_op_loaded()
# TensorRT nms plugin is aligned with original nms in ONNXRuntime
Expand All @@ -35,16 +48,28 @@ def symbolic(g, bboxes, scores, iou_threshold, offset):
offset_i=int(offset))
else:
from torch.onnx.symbolic_opset9 import select, squeeze, unsqueeze
from ..onnx.onnx_utils.symbolic_helper import _size_helper

boxes = unsqueeze(g, bboxes, 0)
scores = unsqueeze(g, unsqueeze(g, scores, 0), 0)
max_output_per_class = g.op(
'Constant',
value_t=torch.tensor([sys.maxsize], dtype=torch.long))

if max_num > 0:
max_num = g.op(
'Constant',
value_t=torch.tensor(max_num, dtype=torch.long))
else:
dim = g.op('Constant', value_t=torch.tensor(0))
max_num = _size_helper(g, bboxes, dim)
max_output_per_class = max_num
iou_threshold = g.op(
'Constant',
value_t=torch.tensor([iou_threshold], dtype=torch.float))
score_threshold = g.op(
'Constant',
value_t=torch.tensor([score_threshold], dtype=torch.float))
nms_out = g.op('NonMaxSuppression', boxes, scores,
max_output_per_class, iou_threshold)
max_output_per_class, iou_threshold,
score_threshold)
return squeeze(
g,
select(
Expand Down Expand Up @@ -90,7 +115,7 @@ def symbolic(g, boxes, scores, iou_threshold, sigma, min_score, method,


@deprecated_api_warning({'iou_thr': 'iou_threshold'})
def nms(boxes, scores, iou_threshold, offset=0):
def nms(boxes, scores, iou_threshold, offset=0, score_threshold=0, max_num=-1):
"""Dispatch to either CPU or GPU NMS implementations.

The input can be either torch tensor or numpy array. GPU NMS will be used
Expand All @@ -102,6 +127,8 @@ def nms(boxes, scores, iou_threshold, offset=0):
scores (torch.Tensor or np.ndarray): scores in shape (N, ).
iou_threshold (float): IoU threshold for NMS.
offset (int, 0 or 1): boxes' width or height is (x2 - x1 + offset).
score_threshold (float): score threshold for NMS.
max_num (int): maximum number of boxes after NMS.

Returns:
tuple: kept dets(boxes and scores) and indice, which is always the \
Expand Down Expand Up @@ -141,7 +168,8 @@ def nms(boxes, scores, iou_threshold, offset=0):
}
inds = ext_module.nms(*indata_list, **indata_dict)
else:
inds = NMSop.apply(boxes, scores, iou_threshold, offset)
inds = NMSop.apply(boxes, scores, iou_threshold, offset,
score_threshold, max_num)
dets = torch.cat((boxes[inds], scores[inds].reshape(-1, 1)), dim=1)
if is_numpy:
dets = dets.cpu().numpy()
Expand Down Expand Up @@ -285,6 +313,7 @@ def batched_nms(boxes, scores, idxs, nms_cfg, class_agnostic=False):
# Some type of nms would reweight the score, such as SoftNMS
scores = dets[:, 4]
else:
max_num = nms_cfg_.pop('max_num', -1)
total_mask = scores.new_zeros(scores.size(), dtype=torch.bool)
# Some type of nms would reweight the score, such as SoftNMS
scores_after_nms = scores.new_zeros(scores.size())
Expand All @@ -294,10 +323,16 @@ def batched_nms(boxes, scores, idxs, nms_cfg, class_agnostic=False):
total_mask[mask[keep]] = True
scores_after_nms[mask[keep]] = dets[:, -1]
keep = total_mask.nonzero(as_tuple=False).view(-1)

scores, inds = scores_after_nms[keep].sort(descending=True)
keep = keep[inds]
boxes = boxes[keep]

if max_num > 0:
keep = keep[:max_num]
boxes = boxes[:max_num]
scores = scores[:max_num]

return torch.cat([boxes, scores[:, None]], -1), keep


Expand Down
31 changes: 23 additions & 8 deletions mmcv/tensorrt/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,27 @@
# flake8: noqa
from .init_plugins import is_tensorrt_plugin_loaded, load_tensorrt_plugin
from .tensorrt_utils import (TRTWraper, load_trt_engine, onnx2trt,
save_trt_engine)
from .preprocess import preprocess_onnx

# load tensorrt plugin lib
load_tensorrt_plugin()

__all__ = [
'onnx2trt', 'save_trt_engine', 'load_trt_engine', 'TRTWraper',
'is_tensorrt_plugin_loaded'
]
def is_tensorrt_available():
try:
import tensorrt
del tensorrt
return True
except ModuleNotFoundError:
return False


__all__ = []

if is_tensorrt_available():
from .tensorrt_utils import (TRTWraper, load_trt_engine, onnx2trt,
save_trt_engine)

# load tensorrt plugin lib
load_tensorrt_plugin()

__all__.append(
['onnx2trt', 'save_trt_engine', 'load_trt_engine', 'TRTWraper'])

__all__.append(['is_tensorrt_plugin_loaded', 'preprocess_onnx'])
117 changes: 117 additions & 0 deletions mmcv/tensorrt/preprocess.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
import numpy as np
import onnx


def preprocess_onnx(onnx_model):
"""Modify onnx model to match with TensorRT plugins in mmcv.

There are some conflict between onnx node definition and TensorRT limit.
This function perform preprocess on the onnx model to solve the conflicts.
For example, onnx `attribute` is loaded in TensorRT on host and onnx
`input` is loaded on device. The shape inference is performed on host, so
any `input` related to shape (such as `max_output_boxes_per_class` in
NonMaxSuppression) should be transformed to `attribute` before conversion.

Arguments:
onnx_model (onnx.ModelProto): Input onnx model.

Returns:
onnx.ModelProto: Modified onnx model.
"""
graph = onnx_model.graph
nodes = graph.node
initializers = graph.initializer
node_dict = {}
for node in nodes:
node_outputs = node.output
for output in node_outputs:
if len(output) > 0:
node_dict[output] = node

init_dict = {_.name: _ for _ in initializers}

nodes_name_to_remove = set()

def is_node_without_output(name):
for node_name, node in node_dict.items():
if node_name not in nodes_name_to_remove:
if name in node.input:
return False
return True

def mark_nodes_to_remove(name):
node = node_dict[name]
nodes_name_to_remove.add(name)
for input_node_name in node.input:
if is_node_without_output(input_node_name):
mark_nodes_to_remove(input_node_name)

def parse_data(name, typ, default_value=0):
if name in node_dict:
node = node_dict[name]
if node.op_type == 'Constant':
raw_data = node.attribute[0].t.raw_data
else:
mark_nodes_to_remove(name)
return default_value
elif name in init_dict:
raw_data = init_dict[name].raw_data
else:
raise ValueError(f'{name} not found in node or initilizer.')
return np.frombuffer(raw_data, typ).item()

nrof_node = len(nodes)
for idx in range(nrof_node):
node = nodes[idx]
node_attributes = node.attribute
node_inputs = node.input
node_outputs = node.output
node_name = node.name
# process NonMaxSuppression node
if node.op_type == 'NonMaxSuppression':
center_point_box = 0
max_output_boxes_per_class = 1000000
iou_threshold = 0.3
score_threshold = 0.0
offset = 0
for attribute in node_attributes:
if attribute.name == 'center_point_box':
center_point_box = attribute.i
elif attribute.name == 'offset':
offset = attribute.i

if len(node_inputs) >= 3:
max_output_boxes_per_class = parse_data(
node_inputs[2], np.int64, max_output_boxes_per_class)
mark_nodes_to_remove(node_inputs[2])

if len(node_inputs) >= 4:
iou_threshold = parse_data(node_inputs[3], np.float32,
iou_threshold)
mark_nodes_to_remove(node_inputs[3])

if len(node_inputs) >= 5:
score_threshold = parse_data(node_inputs[4], np.float32)
mark_nodes_to_remove(node_inputs[4])

new_node = onnx.helper.make_node(
'NonMaxSuppression',
node_inputs[:2],
node_outputs,
name=node_name,
center_point_box=center_point_box,
max_output_boxes_per_class=max_output_boxes_per_class,
iou_threshold=iou_threshold,
score_threshold=score_threshold,
offset=offset)

for output in node_outputs:
if output in node_dict:
node_dict[output] = new_node
nodes.insert(idx, new_node)
nodes.remove(node)

for node_name in nodes_name_to_remove:
nodes.remove(node_dict[node_name])

return onnx_model
90 changes: 1 addition & 89 deletions mmcv/tensorrt/tensorrt_utils.py
Original file line number Diff line number Diff line change
@@ -1,96 +1,8 @@
import numpy as np
import onnx
import tensorrt as trt
import torch


def preprocess_onnx(onnx_model):
"""Modify onnx model to match with TensorRT plugins in mmcv.

There are some conflict between onnx node definition and TensorRT limit.
This function perform preprocess on the onnx model to solve the conflicts.
For example, onnx `attribute` is loaded in TensorRT on host and onnx
`input` is loaded on device. The shape inference is performed on host, so
any `input` related to shape (such as `max_output_boxes_per_class` in
NonMaxSuppression) should be transformed to `attribute` before conversion.

Arguments:
onnx_model (onnx.ModelProto): Input onnx model.

Returns:
onnx.ModelProto: Modified onnx model.
"""
graph = onnx_model.graph
nodes = graph.node
initializers = graph.initializer
node_dict = {}
for node in nodes:
node_outputs = node.output
for output in node_outputs:
if len(output) > 0:
node_dict[output] = node

init_dict = {_.name: _ for _ in initializers}

def parse_data(name, typ):
if name in node_dict:
const_node = node_dict[name]
assert const_node.op_type == 'Constant'
raw_data = const_node.attribute[0].t.raw_data
elif name in init_dict:
raw_data = init_dict[name].raw_data
else:
raise ValueError(f'{name} not found in node or initilizer.')
return np.frombuffer(raw_data, typ).item()

nrof_node = len(nodes)
for idx in range(nrof_node):
node = nodes[idx]
node_attributes = node.attribute
node_inputs = node.input
node_outputs = node.output
node_name = node.name
# process NonMaxSuppression node
if node.op_type == 'NonMaxSuppression':
center_point_box = 0
max_output_boxes_per_class = 1000000
iou_threshold = 0.3
score_threshold = 0.0
offset = 0
for attribute in node_attributes:
if attribute.name == 'center_point_box':
center_point_box = attribute.i
elif attribute.name == 'offset':
offset = attribute.i

if len(node_inputs) >= 3:
max_output_boxes_per_class = parse_data(
node_inputs[2], np.int64)

if len(node_inputs) >= 4:
iou_threshold = parse_data(node_inputs[3], np.float32)

if len(node_inputs) >= 5:
score_threshold = parse_data(node_inputs[4], np.float32)

new_node = onnx.helper.make_node(
'NonMaxSuppression',
node_inputs[:2],
node_outputs,
name=node_name,
center_point_box=center_point_box,
max_output_boxes_per_class=max_output_boxes_per_class,
iou_threshold=iou_threshold,
score_threshold=score_threshold,
offset=offset)

for output in node_outputs:
if output in node_dict:
node_dict[output] = new_node
nodes.insert(idx, new_node)
nodes.remove(node)

return onnx_model
from mmcv.tensorrt.preprocess import preprocess_onnx


def onnx2trt(onnx_model,
Expand Down
10 changes: 8 additions & 2 deletions tests/test_ops/test_nms.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,12 @@ def test_batched_nms(self):
from mmcv.ops import batched_nms
results = mmcv.load('./tests/data/batched_nms_data.pkl')

nms_cfg = dict(type='nms', iou_threshold=0.7)
nms_max_num = 100
nms_cfg = dict(
type='nms',
iou_threshold=0.7,
score_threshold=0.5,
max_num=nms_max_num)
boxes, keep = batched_nms(
torch.from_numpy(results['boxes']),
torch.from_numpy(results['scores']),
Expand All @@ -156,7 +161,8 @@ def test_batched_nms(self):

assert torch.equal(keep, seq_keep)
assert torch.equal(boxes, seq_boxes)
assert torch.equal(keep, torch.from_numpy(results['keep']))
assert torch.equal(keep,
torch.from_numpy(results['keep'][:nms_max_num]))

nms_cfg = dict(type='soft_nms', iou_threshold=0.7)
boxes, keep = batched_nms(
Expand Down
Loading