Skip to content

Commit

Permalink
merge master and resolve conflicts
Browse files Browse the repository at this point in the history
  • Loading branch information
RunningLeon committed May 25, 2021
2 parents f9fc45e + 4d42365 commit 85576ea
Show file tree
Hide file tree
Showing 26 changed files with 1,681 additions and 100 deletions.
8 changes: 4 additions & 4 deletions docs/onnxruntime_op.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,10 @@
| [SoftNMS](onnxruntime_custom_ops.md#softnms) | Y | N | 1.2.3 |
| [RoIAlign](onnxruntime_custom_ops.md#roialign) | Y | N | 1.2.5 |
| [NMS](onnxruntime_custom_ops.md#nms) | Y | N | 1.2.7 |
| [grid_sampler](onnxruntime_custom_ops.md#grid_sampler) | Y | N | master |
| [CornerPool](onnxruntime_custom_ops.md#cornerpool) | Y | N | master |
| [cummax](onnxruntime_custom_ops.md#cummax) | Y | N | master |
| [cummin](onnxruntime_custom_ops.md#cummin) | Y | N | master |
| [grid_sampler](onnxruntime_custom_ops.md#grid_sampler) | Y | N | 1.3.1 |
| [CornerPool](onnxruntime_custom_ops.md#cornerpool) | Y | N | 1.3.4 |
| [cummax](onnxruntime_custom_ops.md#cummax) | Y | N | master |
| [cummin](onnxruntime_custom_ops.md#cummin) | Y | N | master |

## How to build custom operators for ONNX Runtime

Expand Down
42 changes: 42 additions & 0 deletions docs/tensorrt_custom_ops.md
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,12 @@
- [Inputs](#inputs-6)
- [Outputs](#outputs-6)
- [Type Constraints](#type-constraints-6)
- [MMCVInstanceNormalization](#mmcvinstancenormalization)
- [Description](#description-7)
- [Parameters](#parameters-7)
- [Inputs](#inputs-7)
- [Outputs](#outputs-7)
- [Type Constraints](#type-constraints-7)

<!-- TOC -->

Expand Down Expand Up @@ -303,3 +309,39 @@ Returns a namedtuple (`values`, `indices`) where `values` is the cumulative mini
### Type Constraints

- T:tensor(float32, Linear)

## MMCVInstanceNormalization

### Description

Carries out instance normalization as described in the paper https://arxiv.org/abs/1607.08022.

y = scale * (x - mean) / sqrt(variance + epsilon) + B, where mean and variance are computed per instance per channel.

### Parameters

| Type | Parameter | Description |
| ------- | --------- | -------------------------------------------------------------------- |
| `float` | `epsilon` | The epsilon value to use to avoid division by zero. Default is 1e-05 |

### Inputs

<dl>
<dt><tt>input</tt>: T</dt>
<dd>Input data tensor from the previous operator; dimensions for image case are (N x C x H x W), where N is the batch size, C is the number of channels, and H and W are the height and the width of the data. For non image case, the dimensions are in the form of (N x C x D1 x D2 ... Dn), where N is the batch size.</dd>
<dt><tt>scale</tt>: T</dt>
<dd>The input 1-dimensional scale tensor of size C.</dd>
<dt><tt>B</tt>: T</dt>
<dd>The input 1-dimensional bias tensor of size C.</dd>
</dl>

### Outputs

<dl>
<dt><tt>output</tt>: T</dt>
<dd>The output tensor of the same shape as input.</dd>
</dl>

### Type Constraints

- T:tensor(float32, Linear)
19 changes: 10 additions & 9 deletions docs/tensorrt_plugin.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,16 @@ To ease the deployment of trained models with custom operators from `mmcv.ops` u

## List of TensorRT plugins supported in MMCV

| ONNX Operator | TensorRT Plugin | MMCV Releases |
| :---------------: | :-------------------------------------------------------------: | :-----------: |
| MMCVRoiAlign | [MMCVRoiAlign](./tensorrt_custom_ops.md#mmcvroialign) | 1.2.6 |
| ScatterND | [ScatterND](./tensorrt_custom_ops.md#scatternd) | 1.2.6 |
| NonMaxSuppression | [NonMaxSuppression](./tensorrt_custom_ops.md#nonmaxsuppression) | 1.3.0 |
| MMCVDeformConv2d | [MMCVDeformConv2d](./tensorrt_custom_ops.md#mmcvdeformconv2d) | 1.3.0 |
| grid_sampler | [grid_sampler](./tensorrt_custom_ops.md#grid-sampler) | 1.3.1 |
| cummax | [cummax](./tensorrt_custom_ops.md#cummax) | master |
| cummin | [cummin](./tensorrt_custom_ops.md#cummin) | master |
| ONNX Operator | TensorRT Plugin | MMCV Releases |
| :-----------------------: | :-----------------------------------------------------------------------------: | :-----------: |
| MMCVRoiAlign | [MMCVRoiAlign](./tensorrt_custom_ops.md#mmcvroialign) | 1.2.6 |
| ScatterND | [ScatterND](./tensorrt_custom_ops.md#scatternd) | 1.2.6 |
| NonMaxSuppression | [NonMaxSuppression](./tensorrt_custom_ops.md#nonmaxsuppression) | 1.3.0 |
| MMCVDeformConv2d | [MMCVDeformConv2d](./tensorrt_custom_ops.md#mmcvdeformconv2d) | 1.3.0 |
| grid_sampler | [grid_sampler](./tensorrt_custom_ops.md#grid-sampler) | 1.3.1 |
| cummax | [cummax](./tensorrt_custom_ops.md#cummax) | master |
| cummin | [cummin](./tensorrt_custom_ops.md#cummin) | master |
| MMCVInstanceNormalization | [MMCVInstanceNormalization](./tensorrt_custom_ops.md#mmcvinstancenormalization) | master |

Notes

Expand Down
1 change: 0 additions & 1 deletion mmcv/commit_id.py

This file was deleted.

4 changes: 3 additions & 1 deletion mmcv/ops/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from .bbox import bbox_overlaps
from .border_align import BorderAlign, border_align
from .box_iou_rotated import box_iou_rotated
from .carafe import CARAFE, CARAFENaive, CARAFEPack, carafe, carafe_naive
from .cc_attention import CrissCrossAttention
Expand Down Expand Up @@ -48,5 +49,6 @@
'point_sample', 'rel_roi_point_to_rel_img_point', 'SimpleRoIAlign',
'SAConv2d', 'TINShift', 'tin_shift', 'box_iou_rotated', 'nms_rotated',
'upfirdn2d', 'FusedBiasLeakyReLU', 'fused_bias_leakyrelu',
'RoIAlignRotated', 'roi_align_rotated', 'pixel_group', 'contour_expand'
'RoIAlignRotated', 'roi_align_rotated', 'pixel_group', 'contour_expand',
'BorderAlign', 'border_align'
]
108 changes: 108 additions & 0 deletions mmcv/ops/border_align.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
# modified from
# https://github.com/Megvii-BaseDetection/cvpods/blob/master/cvpods/layers/border_align.py

import torch
import torch.nn as nn
from torch.autograd import Function
from torch.autograd.function import once_differentiable

from ..utils import ext_loader

ext_module = ext_loader.load_ext(
'_ext', ['border_align_forward', 'border_align_backward'])


class BorderAlignFunction(Function):

@staticmethod
def symbolic(g, input, boxes, pool_size):
return g.op(
'mmcv::MMCVBorderAlign', input, boxes, pool_size_i=pool_size)

@staticmethod
def forward(ctx, input, boxes, pool_size):
ctx.pool_size = pool_size
ctx.input_shape = input.size()

assert boxes.ndim == 3, 'boxes must be with shape [B, H*W, 4]'
assert boxes.size(2) == 4, \
'the last dimension of boxes must be (x1, y1, x2, y2)'
assert input.size(1) % 4 == 0, \
'the channel for input feature must be divisible by factor 4'

# [B, C//4, H*W, 4]
output_shape = (input.size(0), input.size(1) // 4, boxes.size(1), 4)
output = input.new_zeros(output_shape)
# `argmax_idx` only used for backward
argmax_idx = input.new_zeros(output_shape).to(torch.int)

ext_module.border_align_forward(
input, boxes, output, argmax_idx, pool_size=ctx.pool_size)

ctx.save_for_backward(boxes, argmax_idx)
return output

@staticmethod
@once_differentiable
def backward(ctx, grad_output):
boxes, argmax_idx = ctx.saved_tensors
grad_input = grad_output.new_zeros(ctx.input_shape)
# complex head architecture may cause grad_output uncontiguous
grad_output = grad_output.contiguous()
ext_module.border_align_backward(
grad_output,
boxes,
argmax_idx,
grad_input,
pool_size=ctx.pool_size)
return grad_input, None, None


border_align = BorderAlignFunction.apply


class BorderAlign(nn.Module):
r"""Border align pooling layer.
Applies border_align over the input feature based on predicted bboxes.
The details were described in the paper
`BorderDet: Border Feature for Dense Object Detection
<https://arxiv.org/abs/2007.11056>`_.
For each border line (e.g. top, left, bottom or right) of each box,
border_align does the following:
1. uniformly samples `pool_size`+1 positions on this line, involving \
the start and end points.
2. the corresponding features on these points are computed by \
bilinear interpolation.
3. max pooling over all the `pool_size`+1 positions are used for \
computing pooled feature.
Args:
pool_size (int): number of positions sampled over the boxes' borders
(e.g. top, bottom, left, right).
"""

def __init__(self, pool_size):
super(BorderAlign, self).__init__()
self.pool_size = pool_size

def forward(self, input, boxes):
"""
Args:
input: Features with shape [N,4C,H,W]. Channels ranged in [0,C),
[C,2C), [2C,3C), [3C,4C) represent the top, left, bottom,
right features respectively.
boxes: Boxes with shape [N,H*W,4]. Coordinate format (x1,y1,x2,y2).
Returns:
Tensor: Pooled features with shape [N,C,H*W,4]. The order is
(top,left,bottom,right) for the last dimension.
"""
return border_align(input, boxes, self.pool_size)

def __repr__(self):
s = self.__class__.__name__
s += f'(pool_size={self.pool_size})'
return s
Loading

0 comments on commit 85576ea

Please sign in to comment.