Skip to content

Commit

Permalink
fix satrn onnxruntime batch inference
Browse files Browse the repository at this point in the history
  • Loading branch information
irexyc committed Jun 7, 2023
1 parent 439f88b commit 4dbe9b2
Show file tree
Hide file tree
Showing 3 changed files with 80 additions and 0 deletions.
2 changes: 2 additions & 0 deletions mmdeploy/codebase/mmocr/models/text_recognition/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
from . import crnn_decoder # noqa: F401,F403
from . import encoder_decoder_recognizer # noqa: F401,F403
from . import lstm_layer # noqa: F401,F403
from . import nrtr_decoder # noqa: F401,F403
from . import sar_decoder # noqa: F401,F403
from . import sar_encoder # noqa: F401,F403
from . import satrn_encoder # noqa: F401,F403
from . import transformer_module # noqa: F401,F403
36 changes: 36 additions & 0 deletions mmdeploy/codebase/mmocr/models/text_recognition/nrtr_decoder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# Copyright (c) OpenMMLab. All rights reserved.
import math
from typing import Sequence

import torch

from mmdeploy.core import FUNCTION_REWRITER


@FUNCTION_REWRITER.register_rewriter(
func_name='mmocr.models.textrecog.NRTRDecoder._get_source_mask')
def nrtr_decoder___get_source_mask(
self, src_seq: torch.Tensor,
valid_ratios: Sequence[float]) -> torch.Tensor:
"""Generate mask for source sequence.
Args:
src_seq (torch.Tensor): Image sequence. Shape :math:`(N, T, C)`.
valid_ratios (list[float]): The valid ratio of input image. For
example, if the width of the original image is w1 and the width
after padding is w2, then valid_ratio = w1/w2. Source mask is
used to cover the area of the padding region.
Returns:
Tensor or None: Source mask. Shape :math:`(N, T)`. The region of
padding area are False, and the rest are True.
"""

N, T, _ = src_seq.size()
mask = None
if len(valid_ratios) > 0:
mask = src_seq.new_zeros((N, T), device=src_seq.device)
valid_width = min(T, math.ceil(T * valid_ratios[0]))
mask[:, :valid_width] = 1

return mask
42 changes: 42 additions & 0 deletions mmdeploy/codebase/mmocr/models/text_recognition/satrn_encoder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# Copyright (c) OpenMMLab. All rights reserved.
import math
from typing import List

from mmocr.structures import TextRecogDataSample
from torch import Tensor

from mmdeploy.core import FUNCTION_REWRITER


@FUNCTION_REWRITER.register_rewriter(
func_name='mmocr.models.textrecog.SATRNEncoder.forward')
def satrn_encoder__forward(
self,
feat: Tensor,
data_samples: List[TextRecogDataSample] = None) -> Tensor:
"""Forward propagation of encoder.
Args:
feat (Tensor): Feature tensor of shape :math:`(N, D_m, H, W)`.
data_samples (list[TextRecogDataSample]): Batch of
TextRecogDataSample, containing `valid_ratio` information.
Defaults to None.
Returns:
Tensor: A tensor of shape :math:`(N, T, D_m)`.
"""
valid_ratio = 1.0
feat = self.position_enc(feat)
n, c, h, w = feat.size()
mask = feat.new_zeros((n, h, w))
valid_width = min(w, math.ceil(w * valid_ratio))
mask[:, :, :valid_width] = 1
mask = mask.view(n, h * w)
feat = feat.view(n, c, h * w)

output = feat.permute(0, 2, 1).contiguous()
for enc_layer in self.layer_stack:
output = enc_layer(output, h, w, mask)
output = self.layer_norm(output)

return output

0 comments on commit 4dbe9b2

Please sign in to comment.