forked from open-mmlab/mmdetection3d
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Feature] Add mmocr ncnn support (open-mmlab#53)
* first * fix0 * fix1 * dirty work * wip * add allocator * finally done! * lint * fix lint * better gather * better onnx2ncnn * fix expand * [Fix] NCNN TensorSlice op bugs (open-mmlab#42) * fix custom ops support, fix multiple mark bug, add name mapping * check if the value_info need to be added * remove unnecessary print * add nms implement * two stage split wip * add two stage split * add split retinanet visualize * add two stage split (wip) * finish two stage split * fix lint * move parse string to mmdeploy.utils * add calib data generator * create calib dataset * finish end2end int8 * add split two stage tensorrt visualize * fix tensorslice bugs * fix lint * fix clang-format * remove comments * int param * fix lint Co-authored-by: grimoire <[email protected]> * add two stage ncnn support * remove unused ops * git unused config * remove no_grad, should add in refactor * add ncnn wrapper * fix lint * size return tuple * Resolve grammar error * Fix lint * Trim Trailing Whitespace * fix trim * update wrapper * remove logs * remove * csrc optimize * add ncnn dbnet support * finish crnn support * add comment Co-authored-by: hanrui1sensetime <[email protected]>
- Loading branch information
1 parent
2b98040
commit e73d9fb
Showing
13 changed files
with
169 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
_base_ = ['../_base_/torch2onnx.py'] | ||
codebase = 'mmocr' | ||
|
||
# 'TextDetection' or 'TextRecognition' | ||
task = 'TextDetection' | ||
pytorch2onnx = dict(input_names=['input'], output_names=['output']) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
_base_ = ['./base_static.py', '../_base_/backends/ncnn.py'] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from .lstm_layer import forward_of_bidirectionallstm | ||
|
||
__all__ = ['forward_of_bidirectionallstm'] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
from mmdeploy.core import FUNCTION_REWRITER | ||
|
||
|
||
@FUNCTION_REWRITER.register_rewriter( | ||
func_name='mmocr.models.textrecog.layers.lstm_layer' | ||
'.BidirectionalLSTM.forward', | ||
backend='ncnn') | ||
def forward_of_bidirectionallstm(ctx, self, input): | ||
self.rnn.batch_first = True | ||
recurrent, _ = self.rnn(input) | ||
self.rnn.batch_first = False | ||
|
||
output = self.embedding(recurrent) | ||
|
||
return output |
3 changes: 3 additions & 0 deletions
3
mmdeploy/mmocr/models/textrecog/recognizer/decoders/__init__.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from .crnn_decoder import forward_train_of_crnndecoder | ||
|
||
__all__ = ['forward_train_of_crnndecoder'] |
19 changes: 19 additions & 0 deletions
19
mmdeploy/mmocr/models/textrecog/recognizer/decoders/crnn_decoder.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
from mmdeploy.core import FUNCTION_REWRITER | ||
|
||
|
||
@FUNCTION_REWRITER.register_rewriter( | ||
func_name='mmocr.models.textrecog.decoders.CRNNDecoder.forward_train', | ||
backend='ncnn') | ||
def forward_train_of_crnndecoder(ctx, self, feat, out_enc, targets_dict, | ||
img_metas): | ||
assert feat.size(2) == 1, 'feature height must be 1' | ||
if self.rnn_flag: | ||
x = feat.squeeze(2) # [N, C, W] | ||
x = x.permute(0, 2, 1) # [N, W, C] | ||
outputs = self.decoder(x) | ||
else: | ||
x = self.decoder(feat) | ||
x = x.permute(0, 3, 1, 2).contiguous() | ||
n, w, c, h = x.size() | ||
outputs = x.view(n, w, c * h) | ||
return outputs |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,10 +1,11 @@ | ||
from .getattribute import getattribute_static | ||
from .interpolate import interpolate_static | ||
from .linear import linear_ncnn | ||
from .repeat import repeat_static | ||
from .size import size_of_tensor_static | ||
from .topk import topk_dynamic, topk_static | ||
|
||
__all__ = [ | ||
'getattribute_static', 'interpolate_static', 'repeat_static', | ||
'size_of_tensor_static', 'topk_static', 'topk_dynamic' | ||
'getattribute_static', 'interpolate_static', 'linear_ncnn', | ||
'repeat_static', 'size_of_tensor_static', 'topk_static', 'topk_dynamic' | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
from typing import Union | ||
|
||
import torch | ||
|
||
from mmdeploy.core import FUNCTION_REWRITER | ||
|
||
|
||
@FUNCTION_REWRITER.register_rewriter( | ||
func_name='torch.nn.functional.linear', backend='ncnn') | ||
def linear_ncnn( | ||
ctx, | ||
input: torch.Tensor, | ||
weight: torch.Tensor, | ||
bias: Union[torch.Tensor, torch.NoneType] = None, | ||
): | ||
origin_func = ctx.origin_func | ||
|
||
dim = input.dim() | ||
|
||
if dim == 2: | ||
return origin_func(input, weight, bias) | ||
else: | ||
out = origin_func(input, weight) | ||
|
||
# permute | ||
out = out.transpose(1, dim - 1) | ||
|
||
# ncnn only support [c, h, w] and [c, 1, 1] broadcast | ||
out_shape = out.shape | ||
batch_size = out_shape[0] | ||
broad_cast_size = out_shape[1] | ||
out = out.reshape([batch_size, broad_cast_size, -1, 1]) | ||
|
||
# add bias | ||
bias = bias.view([1, -1, 1, 1]) | ||
out = out + bias | ||
|
||
# permute back | ||
out = out.reshape(out_shape) | ||
out = out.transpose(1, dim - 1) | ||
|
||
return out |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
import torch.onnx.symbolic_helper as sym_help | ||
|
||
from mmdeploy.core import SYMBOLIC_REGISTER | ||
|
||
|
||
@SYMBOLIC_REGISTER.register_symbolic('squeeze', is_pytorch=True) | ||
def squeeze_default(ctx, g, self, dim=None): | ||
if dim is None: | ||
dims = [] | ||
for i, size in enumerate(self.type().sizes()): | ||
if size == 1: | ||
dims.append(i) | ||
else: | ||
dims = [sym_help._get_const(dim, 'i', 'dim')] | ||
return g.op('Squeeze', self, axes_i=dims) |