Skip to content

Commit

Permalink
[Feature] NMS update (#957)
Browse files Browse the repository at this point in the history
* Add score_threshold and max_num to NMS

* Fix codestyle

* Fix codestyle

* Fix inds in nms

* Update nms docstring

* Move score_threshold and max_num arguments

* Fix args order in docstring

* fix lint of c++ file

* Remove torch.onnx.is_in_onnx_export() and add max_num to batched_nms for separate classes.

* Rewrote max_num handling in NMSop.symbolic

* Added processing max_output_boxes_per_class when exporting to TensorRT

* Added score_threshold and max_num for NMS in test_onnx.py and test_tensorrt.py

* Remove _is_value(max_num)

* fix ci errors with torch==1.3.1

* Update test_batched_nms in test_nms.py

* Added tests for preprocess_onnx

* Moved 'test_tensorrt_preprocess.py' and 'preprocess', updated 'remove_tmp_file'.

* Update mmcv/tensorrt/__init__.py

* Fix segfault torch==1.3.1 (remove onnx.checker.check_model)

* Returned 'onnx.checker.check_model' with torch version check

* Changed torch version from 1.3.1 to 1.4.0

* update version check

* remove check for onnx

Co-authored-by: maningsheng <[email protected]>
  • Loading branch information
SemyonBevzuk and RunningLeon authored May 31, 2021
1 parent 717d157 commit bf2c9fa
Show file tree
Hide file tree
Showing 10 changed files with 289 additions and 123 deletions.
8 changes: 4 additions & 4 deletions docs/onnxruntime_op.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,10 @@
| [SoftNMS](onnxruntime_custom_ops.md#softnms) | Y | N | 1.2.3 |
| [RoIAlign](onnxruntime_custom_ops.md#roialign) | Y | N | 1.2.5 |
| [NMS](onnxruntime_custom_ops.md#nms) | Y | N | 1.2.7 |
| [grid_sampler](onnxruntime_custom_ops.md#grid_sampler) | Y | N | master |
| [CornerPool](onnxruntime_custom_ops.md#cornerpool) | Y | N | master |
| [cummax](onnxruntime_custom_ops.md#cummax) | Y | N | master |
| [cummin](onnxruntime_custom_ops.md#cummin) | Y | N | master |
| [grid_sampler](onnxruntime_custom_ops.md#grid_sampler) | Y | N | 1.3.1 |
| [CornerPool](onnxruntime_custom_ops.md#cornerpool) | Y | N | 1.3.4 |
| [cummax](onnxruntime_custom_ops.md#cummax) | Y | N | master |
| [cummin](onnxruntime_custom_ops.md#cummin) | Y | N | master |

## How to build custom operators for ONNX Runtime

Expand Down
1 change: 1 addition & 0 deletions docs/tensorrt_plugin.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ To ease the deployment of trained models with custom operators from `mmcv.ops` u
| cummax | [cummax](./tensorrt_custom_ops.md#cummax) | master |
| cummin | [cummin](./tensorrt_custom_ops.md#cummin) | master |
| MMCVInstanceNormalization | [MMCVInstanceNormalization](./tensorrt_custom_ops.md#mmcvinstancenormalization) | master |

Notes

- All plugins listed above are developed on TensorRT-7.2.1.6.Ubuntu-16.04.x86_64-gnu.cuda-10.2.cudnn8.0
Expand Down
53 changes: 44 additions & 9 deletions mmcv/ops/nms.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import os
import sys

import numpy as np
import torch
Expand All @@ -15,13 +14,27 @@
class NMSop(torch.autograd.Function):

@staticmethod
def forward(ctx, bboxes, scores, iou_threshold, offset):
def forward(ctx, bboxes, scores, iou_threshold, offset, score_threshold,
            max_num):
    """Run NMS, optionally pre-filtering by score and capping the result.

    Args:
        ctx: autograd context (unused here).
        bboxes (torch.Tensor): boxes passed to the compiled NMS kernel.
        scores (torch.Tensor): one score per box.
        iou_threshold (float): IoU threshold forwarded to the kernel.
        offset (int): box size offset forwarded to the kernel.
        score_threshold (float): when greater than 0, boxes whose score is
            not strictly above this value are dropped before NMS.
        max_num (int): when greater than 0, keep at most this many indices.

    Returns:
        torch.Tensor: kept indices, expressed relative to the *unfiltered*
        ``bboxes`` / ``scores`` inputs.
    """
    filter_by_score = score_threshold > 0
    if filter_by_score:
        keep_mask = scores > score_threshold
        # Positions of the surviving boxes in the original input; used
        # below to translate kernel output back to original indices.
        original_inds = torch.nonzero(
            keep_mask, as_tuple=False).squeeze(dim=1)
        bboxes = bboxes[keep_mask]
        scores = scores[keep_mask]

    inds = ext_module.nms(
        bboxes, scores, iou_threshold=float(iou_threshold), offset=offset)

    if max_num > 0:
        inds = inds[:max_num]
    return original_inds[inds] if filter_by_score else inds

@staticmethod
def symbolic(g, bboxes, scores, iou_threshold, offset):
def symbolic(g, bboxes, scores, iou_threshold, offset, score_threshold,
max_num):
from ..onnx import is_custom_op_loaded
has_custom_op = is_custom_op_loaded()
# TensorRT nms plugin is aligned with original nms in ONNXRuntime
Expand All @@ -35,16 +48,28 @@ def symbolic(g, bboxes, scores, iou_threshold, offset):
offset_i=int(offset))
else:
from torch.onnx.symbolic_opset9 import select, squeeze, unsqueeze
from ..onnx.onnx_utils.symbolic_helper import _size_helper

boxes = unsqueeze(g, bboxes, 0)
scores = unsqueeze(g, unsqueeze(g, scores, 0), 0)
max_output_per_class = g.op(
'Constant',
value_t=torch.tensor([sys.maxsize], dtype=torch.long))

if max_num > 0:
max_num = g.op(
'Constant',
value_t=torch.tensor(max_num, dtype=torch.long))
else:
dim = g.op('Constant', value_t=torch.tensor(0))
max_num = _size_helper(g, bboxes, dim)
max_output_per_class = max_num
iou_threshold = g.op(
'Constant',
value_t=torch.tensor([iou_threshold], dtype=torch.float))
score_threshold = g.op(
'Constant',
value_t=torch.tensor([score_threshold], dtype=torch.float))
nms_out = g.op('NonMaxSuppression', boxes, scores,
max_output_per_class, iou_threshold)
max_output_per_class, iou_threshold,
score_threshold)
return squeeze(
g,
select(
Expand Down Expand Up @@ -90,7 +115,7 @@ def symbolic(g, boxes, scores, iou_threshold, sigma, min_score, method,


@deprecated_api_warning({'iou_thr': 'iou_threshold'})
def nms(boxes, scores, iou_threshold, offset=0):
def nms(boxes, scores, iou_threshold, offset=0, score_threshold=0, max_num=-1):
"""Dispatch to either CPU or GPU NMS implementations.
The input can be either torch tensor or numpy array. GPU NMS will be used
Expand All @@ -102,6 +127,8 @@ def nms(boxes, scores, iou_threshold, offset=0):
scores (torch.Tensor or np.ndarray): scores in shape (N, ).
iou_threshold (float): IoU threshold for NMS.
offset (int, 0 or 1): boxes' width or height is (x2 - x1 + offset).
score_threshold (float): score threshold for NMS.
max_num (int): maximum number of boxes after NMS.
Returns:
tuple: kept dets(boxes and scores) and indice, which is always the \
Expand Down Expand Up @@ -141,7 +168,8 @@ def nms(boxes, scores, iou_threshold, offset=0):
}
inds = ext_module.nms(*indata_list, **indata_dict)
else:
inds = NMSop.apply(boxes, scores, iou_threshold, offset)
inds = NMSop.apply(boxes, scores, iou_threshold, offset,
score_threshold, max_num)
dets = torch.cat((boxes[inds], scores[inds].reshape(-1, 1)), dim=1)
if is_numpy:
dets = dets.cpu().numpy()
Expand Down Expand Up @@ -285,6 +313,7 @@ def batched_nms(boxes, scores, idxs, nms_cfg, class_agnostic=False):
# Some type of nms would reweight the score, such as SoftNMS
scores = dets[:, 4]
else:
max_num = nms_cfg_.pop('max_num', -1)
total_mask = scores.new_zeros(scores.size(), dtype=torch.bool)
# Some type of nms would reweight the score, such as SoftNMS
scores_after_nms = scores.new_zeros(scores.size())
Expand All @@ -294,10 +323,16 @@ def batched_nms(boxes, scores, idxs, nms_cfg, class_agnostic=False):
total_mask[mask[keep]] = True
scores_after_nms[mask[keep]] = dets[:, -1]
keep = total_mask.nonzero(as_tuple=False).view(-1)

scores, inds = scores_after_nms[keep].sort(descending=True)
keep = keep[inds]
boxes = boxes[keep]

if max_num > 0:
keep = keep[:max_num]
boxes = boxes[:max_num]
scores = scores[:max_num]

return torch.cat([boxes, scores[:, None]], -1), keep


Expand Down
33 changes: 25 additions & 8 deletions mmcv/tensorrt/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,29 @@
# flake8: noqa
from .init_plugins import is_tensorrt_plugin_loaded, load_tensorrt_plugin
from .tensorrt_utils import (TRTWraper, TRTWrapper, load_trt_engine, onnx2trt,
save_trt_engine)
from .preprocess import preprocess_onnx

# load tensorrt plugin lib
load_tensorrt_plugin()

__all__ = [
'onnx2trt', 'save_trt_engine', 'load_trt_engine', 'TRTWrapper',
'TRTWraper', 'is_tensorrt_plugin_loaded'
]
def is_tensorrt_available():
    """Report whether the ``tensorrt`` package can be imported.

    Returns:
        bool: ``True`` if ``import tensorrt`` succeeds, ``False`` if it
        raises ``ModuleNotFoundError``. Any other import-time error is
        propagated to the caller.
    """
    try:
        import tensorrt  # noqa: F401
    except ModuleNotFoundError:
        return False
    else:
        return True


# Public API of the package. It is built incrementally because the
# TensorRT-backed helpers are only importable when ``tensorrt`` is installed.
__all__ = []

if is_tensorrt_available():
    from .tensorrt_utils import (TRTWraper, TRTWrapper, load_trt_engine,
                                 onnx2trt, save_trt_engine)

    # load tensorrt plugin lib
    load_tensorrt_plugin()

    # ``__all__`` must stay a flat list of strings, so use ``extend``:
    # ``append`` would insert the whole list as a single element and make
    # ``from mmcv.tensorrt import *`` fail.
    __all__.extend([
        'onnx2trt', 'save_trt_engine', 'load_trt_engine', 'TRTWraper',
        'TRTWrapper'
    ])

__all__.extend(['is_tensorrt_plugin_loaded', 'preprocess_onnx'])
120 changes: 120 additions & 0 deletions mmcv/tensorrt/preprocess.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
import numpy as np
import onnx


def preprocess_onnx(onnx_model):
    """Modify onnx model to match with TensorRT plugins in mmcv.

    There are some conflicts between the onnx node definition and TensorRT
    limits. This function preprocesses the onnx model to resolve them.
    For example, an onnx `attribute` is loaded in TensorRT on host and an
    onnx `input` is loaded on device. Shape inference is performed on host,
    so any `input` related to shape (such as `max_output_boxes_per_class` in
    NonMaxSuppression) should be transformed to an `attribute` before
    conversion.

    The model is modified in place:

    - every ``NonMaxSuppression`` node is replaced by a node that keeps only
      its first two inputs and carries the remaining settings as attributes;
      producer nodes that fed those removed inputs and are no longer consumed
      are deleted from the graph;
    - every ``InstanceNormalization`` node has its ``op_type`` renamed to
      ``MMCVInstanceNormalization``.

    Arguments:
        onnx_model (onnx.ModelProto): Input onnx model.

    Returns:
        onnx.ModelProto: Modified onnx model (the same object that was
        passed in).

    Raises:
        ValueError: if a non-empty NonMaxSuppression input name is neither
            produced by a node nor present among the graph initializers.
    """
    graph = onnx_model.graph
    nodes = graph.node
    initializers = graph.initializer

    # Map every non-empty output name to the node that produces it.
    node_dict = {}
    for node in nodes:
        for output in node.output:
            if len(output) > 0:
                node_dict[output] = node

    init_dict = {init.name: init for init in initializers}

    # Output names whose producer nodes are scheduled for deletion.
    nodes_name_to_remove = set()

    def is_node_without_output(name):
        """Return True if no node that is still kept consumes ``name``."""
        for node_name, node in node_dict.items():
            if node_name not in nodes_name_to_remove:
                if name in node.input:
                    return False
        return True

    def mark_nodes_to_remove(name):
        """Schedule the producer of ``name`` and its orphaned ancestors."""
        # Initializers, graph inputs and omitted optional inputs have no
        # producer node, so there is nothing to remove for them.
        node = node_dict.get(name)
        if node is None:
            return
        nodes_name_to_remove.add(name)
        for input_node_name in node.input:
            if is_node_without_output(input_node_name):
                mark_nodes_to_remove(input_node_name)

    def parse_data(name, typ, default_value=0):
        """Read the scalar stored under ``name`` or fall back to a default."""
        if not name:
            # An empty name marks an omitted optional input.
            return default_value
        if name in node_dict:
            node = node_dict[name]
            if node.op_type == 'Constant':
                raw_data = node.attribute[0].t.raw_data
            else:
                # The value is computed at runtime and cannot be folded into
                # an attribute; drop its producer and use the default.
                mark_nodes_to_remove(name)
                return default_value
        elif name in init_dict:
            raw_data = init_dict[name].raw_data
        else:
            raise ValueError(f'{name} not found in node or initilizer.')
        return np.frombuffer(raw_data, typ).item()

    # Index-based loop: NonMaxSuppression nodes are swapped out in place
    # (insert + remove at the same index), which keeps the length constant.
    nrof_node = len(nodes)
    for idx in range(nrof_node):
        node = nodes[idx]
        node_attributes = node.attribute
        node_inputs = node.input
        node_outputs = node.output
        node_name = node.name
        # process NonMaxSuppression node
        if node.op_type == 'NonMaxSuppression':
            center_point_box = 0
            max_output_boxes_per_class = 1000000
            iou_threshold = 0.3
            score_threshold = 0.0
            offset = 0
            for attribute in node_attributes:
                if attribute.name == 'center_point_box':
                    center_point_box = attribute.i
                elif attribute.name == 'offset':
                    offset = attribute.i

            if len(node_inputs) >= 3:
                max_output_boxes_per_class = parse_data(
                    node_inputs[2], np.int64, max_output_boxes_per_class)
                mark_nodes_to_remove(node_inputs[2])

            if len(node_inputs) >= 4:
                iou_threshold = parse_data(node_inputs[3], np.float32,
                                           iou_threshold)
                mark_nodes_to_remove(node_inputs[3])

            if len(node_inputs) >= 5:
                # Pass the float default explicitly so the fallback value
                # stays a float rather than parse_data's integer 0.
                score_threshold = parse_data(node_inputs[4], np.float32,
                                             score_threshold)
                mark_nodes_to_remove(node_inputs[4])

            new_node = onnx.helper.make_node(
                'NonMaxSuppression',
                node_inputs[:2],
                node_outputs,
                name=node_name,
                center_point_box=center_point_box,
                max_output_boxes_per_class=max_output_boxes_per_class,
                iou_threshold=iou_threshold,
                score_threshold=score_threshold,
                offset=offset)

            for output in node_outputs:
                if output in node_dict:
                    node_dict[output] = new_node
            nodes.insert(idx, new_node)
            nodes.remove(node)
        elif node.op_type == 'InstanceNormalization':
            # directly change op name
            node.op_type = 'MMCVInstanceNormalization'

    # A node with several outputs may be marked under more than one name;
    # remove each node object only once.
    removed_ids = set()
    for node_name in nodes_name_to_remove:
        node = node_dict[node_name]
        if id(node) in removed_ids:
            continue
        removed_ids.add(id(node))
        nodes.remove(node)

    return onnx_model
Loading

0 comments on commit bf2c9fa

Please sign in to comment.