-
Notifications
You must be signed in to change notification settings - Fork 649
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
fix satrn onnxruntime batch inference
- Loading branch information
Showing
3 changed files
with
80 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
36 changes: 36 additions & 0 deletions
36
mmdeploy/codebase/mmocr/models/text_recognition/nrtr_decoder.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
import math | ||
from typing import Sequence | ||
|
||
import torch | ||
|
||
from mmdeploy.core import FUNCTION_REWRITER | ||
|
||
|
||
@FUNCTION_REWRITER.register_rewriter(
    func_name='mmocr.models.textrecog.NRTRDecoder._get_source_mask')
def nrtr_decoder___get_source_mask(
        self, src_seq: torch.Tensor,
        valid_ratios: Sequence[float]) -> torch.Tensor:
    """Build the source-sequence mask with one shared valid width.

    Rewrite of ``NRTRDecoder._get_source_mask``. Only the first entry of
    ``valid_ratios`` is read, and the valid width derived from it is applied
    to every row of the mask.

    Args:
        src_seq (torch.Tensor): Image sequence. Shape :math:`(N, T, C)`.
        valid_ratios (Sequence[float]): Valid ratios of the input images.
            For example, if the width of the original image is w1 and the
            width after padding is w2, then valid_ratio = w1/w2. Only
            ``valid_ratios[0]`` is used here.

    Returns:
        Tensor or None: Source mask of shape :math:`(N, T)`, created with
        ``src_seq.new_zeros``. The first ``min(T, ceil(T * ratio))``
        positions of each row are 1 and the remaining (padding) positions
        are 0. ``None`` is returned when ``valid_ratios`` is empty.
    """
    batch_size, seq_len, _ = src_seq.size()

    # Nothing to mask against when no ratio is supplied.
    if len(valid_ratios) == 0:
        return None

    source_mask = src_seq.new_zeros((batch_size, seq_len),
                                    device=src_seq.device)
    # A single width, taken from the first sample, covers the whole batch.
    visible_len = min(seq_len, math.ceil(seq_len * valid_ratios[0]))
    source_mask[:, :visible_len] = 1
    return source_mask
42 changes: 42 additions & 0 deletions
42
mmdeploy/codebase/mmocr/models/text_recognition/satrn_encoder.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
import math | ||
from typing import List | ||
|
||
from mmocr.structures import TextRecogDataSample | ||
from torch import Tensor | ||
|
||
from mmdeploy.core import FUNCTION_REWRITER | ||
|
||
|
||
@FUNCTION_REWRITER.register_rewriter(
    func_name='mmocr.models.textrecog.SATRNEncoder.forward')
def satrn_encoder__forward(
        self,
        feat: Tensor,
        data_samples: List[TextRecogDataSample] = None) -> Tensor:
    """Run the SATRN encoder with a fixed valid ratio of 1.0.

    Rewrite of ``SATRNEncoder.forward``. ``data_samples`` is accepted for
    signature compatibility but is not read: the valid ratio is hard-coded
    to 1.0, so the mask handed to every encoder layer is 1 at all positions.

    Args:
        feat (Tensor): Feature tensor of shape :math:`(N, D_m, H, W)`.
        data_samples (list[TextRecogDataSample]): Batch of
            TextRecogDataSample. Unused in this rewrite.
            Defaults to None.

    Returns:
        Tensor: A tensor of shape :math:`(N, T, D_m)`, where
        :math:`T = H * W`.
    """
    full_ratio = 1.0
    encoded = self.position_enc(feat)
    batch, channels, height, width = encoded.size()

    # Spatial mask, flattened to one entry per (h, w) position.
    attn_mask = encoded.new_zeros((batch, height, width))
    visible_cols = min(width, math.ceil(width * full_ratio))
    attn_mask[:, :, :visible_cols] = 1
    attn_mask = attn_mask.view(batch, height * width)

    # (N, C, H, W) -> (N, C, H*W) -> (N, H*W, C)
    flat_feat = encoded.view(batch, channels, height * width)
    tokens = flat_feat.permute(0, 2, 1).contiguous()

    for layer in self.layer_stack:
        tokens = layer(tokens, height, width, attn_mask)

    return self.layer_norm(tokens)