Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Fix] fix satrn onnxruntime batch inference #2139

Merged
merged 3 commits into from
Jun 12, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions mmdeploy/codebase/mmocr/models/text_recognition/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
from . import crnn_decoder # noqa: F401,F403
from . import encoder_decoder_recognizer # noqa: F401,F403
from . import lstm_layer # noqa: F401,F403
from . import nrtr_decoder # noqa: F401,F403
from . import sar_decoder # noqa: F401,F403
from . import sar_encoder # noqa: F401,F403
from . import satrn_encoder # noqa: F401,F403
from . import transformer_module # noqa: F401,F403
36 changes: 36 additions & 0 deletions mmdeploy/codebase/mmocr/models/text_recognition/nrtr_decoder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# Copyright (c) OpenMMLab. All rights reserved.
import math
from typing import Sequence

import torch

from mmdeploy.core import FUNCTION_REWRITER


@FUNCTION_REWRITER.register_rewriter(
func_name='mmocr.models.textrecog.NRTRDecoder._get_source_mask')
def nrtr_decoder___get_source_mask(
irexyc marked this conversation as resolved.
Show resolved Hide resolved
self, src_seq: torch.Tensor,
valid_ratios: Sequence[float]) -> torch.Tensor:
"""Generate mask for source sequence.

Args:
src_seq (torch.Tensor): Image sequence. Shape :math:`(N, T, C)`.
valid_ratios (list[float]): The valid ratio of input image. For
example, if the width of the original image is w1 and the width
after padding is w2, then valid_ratio = w1/w2. Source mask is
used to cover the area of the padding region.

Returns:
Tensor or None: Source mask. Shape :math:`(N, T)`. The region of
padding area are False, and the rest are True.
"""

N, T, _ = src_seq.size()
mask = None
if len(valid_ratios) > 0:
mask = src_seq.new_zeros((N, T), device=src_seq.device)
valid_width = min(T, math.ceil(T * valid_ratios[0]))
mask[:, :valid_width] = 1

return mask
42 changes: 42 additions & 0 deletions mmdeploy/codebase/mmocr/models/text_recognition/satrn_encoder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# Copyright (c) OpenMMLab. All rights reserved.
import math
from typing import List

from mmocr.structures import TextRecogDataSample
from torch import Tensor

from mmdeploy.core import FUNCTION_REWRITER


@FUNCTION_REWRITER.register_rewriter(
func_name='mmocr.models.textrecog.SATRNEncoder.forward')
def satrn_encoder__forward(
self,
feat: Tensor,
data_samples: List[TextRecogDataSample] = None) -> Tensor:
"""Forward propagation of encoder.

Args:
feat (Tensor): Feature tensor of shape :math:`(N, D_m, H, W)`.
data_samples (list[TextRecogDataSample]): Batch of
TextRecogDataSample, containing `valid_ratio` information.
Defaults to None.

Returns:
Tensor: A tensor of shape :math:`(N, T, D_m)`.
"""
valid_ratio = 1.0
feat = self.position_enc(feat)
n, c, h, w = feat.size()
mask = feat.new_zeros((n, h, w))
valid_width = min(w, math.ceil(w * valid_ratio))
mask[:, :, :valid_width] = 1
mask = mask.view(n, h * w)
feat = feat.view(n, c, h * w)

output = feat.permute(0, 2, 1).contiguous()
for enc_layer in self.layer_stack:
output = enc_layer(output, h, w, mask)
output = self.layer_norm(output)

return output