[Bug] SATRN gets wrong results when batch size > 1 in inference #2111

Closed
2 of 3 tasks
CescMessi opened this issue May 24, 2023 · 5 comments

@CescMessi
Contributor

CescMessi commented May 24, 2023

Checklist

  • I have searched related issues but cannot get the expected help.
  • I have read the FAQ documentation but cannot get the expected help.
  • The bug has not been fixed in the latest version.

Describe the bug

I converted SATRN to ONNX using the dynamic config. With batch size = 1 the inference result is correct, but with batch size > 1 the results for all images except the first come out the same.

Conversion command:

python tools/deploy.py configs/mmocr/text-recognition/text-recognition_onnxruntime_dynamic.py /code/mmocr/run/config.py epoch_20.pth 0.jpg --work-dir work_dir_ocr --dump-info

Inference code:

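import cv2
import numpy as np
import onnxruntime as ort

# ONNX model exported by tools/deploy.py with the dynamic config
sess = ort.InferenceSession('work_dir_ocr/end2end.onnx', providers=['CUDAExecutionProvider'])
# Normalization constants, same as data_preprocessor in config.py
mean = np.array([123.675, 116.28, 103.53]).reshape(3, 1, 1)
std = np.array([58.395, 57.12, 57.375]).reshape(3, 1, 1)
img = cv2.imread('a.jpg')
img = cv2.resize(img, (100, 32))  # (width, height), as in the test pipeline's Resize
# BGR -> RGB, HWC -> CHW, then add a batch axis
chw = img[:, :, ::-1].transpose(2, 0, 1)[np.newaxis, :]
# Batch of 3 copies of the same image
inputs = np.concatenate([chw, chw, chw], axis=0)
inputs_n = ((inputs - mean) / std).astype(np.float32)
outputs = sess.run(['output'], {'input': inputs_n})
# outputs[0] has shape (batch, seq_len, num_classes); take the best class per step
print(np.argmax(outputs[0], axis=2))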

It prints:

[[ 0  0  0  0  0  3 10 10  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0
   0]
 [ 0  1  3 10  0  1 10 10  0  1  0 10  0  0  0  0  0  0  0  0  3  3 10  0
   0]
 [ 0  1  3 10  0  1 10 10  0  1  0 10  0  0  0  0  0  0  0  0  3  3 10  0
   0]]
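A small helper can make the symptom concrete by comparing each row of a batched run with a batch-size-1 run of the same image. This is a hypothetical sketch: `check_batch_consistency` and the `run` callable are names introduced here, and `run` is assumed to map an (N, C, H, W) float32 array to an array whose first axis is the batch (for example the (N, T, num_classes) logits returned by the session).

```python
import numpy as np


def check_batch_consistency(run, batch, atol=1e-4):
    """Hypothetical helper: compare batched inference with per-image inference.

    `run` is assumed to take an (N, ...) float32 array and return an array of
    shape (N, ...). Returns one bool per image: True if the batched output for
    that image matches running the image on its own.
    """
    batched = np.asarray(run(batch))
    results = []
    for i in range(batch.shape[0]):
        # Slice with i:i+1 to keep the batch axis (batch size 1).
        single = np.asarray(run(batch[i:i + 1]))[0]
        results.append(bool(np.allclose(batched[i], single, atol=atol)))
    return results


if __name__ == "__main__":
    # Toy stand-in for a model with the reported behaviour: only the first
    # image in a batch is processed, the others get a fixed output.
    def buggy_run(x):
        out = np.zeros((x.shape[0], 4), dtype=np.float32)
        out[0] = x[0].mean()
        return out

    batch = np.ones((3, 3, 32, 100), dtype=np.float32)
    print(check_batch_consistency(buggy_run, batch))  # [True, False, False]
```

With the session from the report this could be called as check_batch_consistency(lambda x: sess.run(['output'], {'input': x})[0], inputs_n); for a model that handles dynamic batches correctly every entry should be True.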

It seems that some variables are fixed as constants when converting to ONNX; the log shows these warnings:

05/24 03:36:36 - mmengine - WARNING - Failed to search registry with scope "mmocr" in the "Codebases" registry tree. As a workaround, the current "Codebases" registry in "mmdeploy" is used to build instance. This may cause unexpected failure when running the built modules. Please check whether "mmocr" is a correct scope, or whether the registry is initialized.
05/24 03:36:36 - mmengine - WARNING - Failed to search registry with scope "mmocr" in the "mmocr_tasks" registry tree. As a workaround, the current "mmocr_tasks" registry in "mmdeploy" is used to build instance. This may cause unexpected failure when running the built modules. Please check whether "mmocr" is a correct scope, or whether the registry is initialized.
05/24 03:36:38 - mmengine - INFO - Start pipeline mmdeploy.apis.pytorch2onnx.torch2onnx in subprocess
05/24 03:36:40 - mmengine - WARNING - Failed to search registry with scope "mmocr" in the "Codebases" registry tree. As a workaround, the current "Codebases" registry in "mmdeploy" is used to build instance. This may cause unexpected failure when running the built modules. Please check whether "mmocr" is a correct scope, or whether the registry is initialized.
05/24 03:36:40 - mmengine - WARNING - Failed to search registry with scope "mmocr" in the "mmocr_tasks" registry tree. As a workaround, the current "mmocr_tasks" registry in "mmdeploy" is used to build instance. This may cause unexpected failure when running the built modules. Please check whether "mmocr" is a correct scope, or whether the registry is initialized.
05/24 03:36:40 - mmengine - WARNING - Failed to search registry with scope "mmocr" in the "Codebases" registry tree. As a workaround, the current "Codebases" registry in "mmdeploy" is used to build instance. This may cause unexpected failure when running the built modules. Please check whether "mmocr" is a correct scope, or whether the registry is initialized.
Loads checkpoint by local backend from path: /code/mmocr/run_test1/epoch_20.pth
05/24 03:36:41 - mmengine - WARNING - DeprecationWarning: get_onnx_config will be deprecated in the future.
05/24 03:36:41 - mmengine - INFO - Export PyTorch model to ONNX: work_dir_ocr/end2end.onnx.
05/24 03:36:41 - mmengine - WARNING - Can not find torch._C._jit_pass_onnx_autograd_function_process, function rewrite will not be applied
05/24 03:36:41 - mmengine - WARNING - Can not find mmdet.models.dense_heads.DETRHead.forward_single, function rewrite will not be applied
/code/mmocr/mmocr/models/textrecog/encoders/satrn_encoder.py:85: TracerWarning: Converting a tensor to a Python float might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  valid_width = min(w, math.ceil(w * valid_ratio))
/code/mmocr/mmocr/models/textrecog/encoders/satrn_encoder.py:85: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  valid_width = min(w, math.ceil(w * valid_ratio))
/code/mmocr/mmocr/models/textrecog/decoders/nrtr_decoder.py:137: TracerWarning: Converting a tensor to a Python float might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  valid_width = min(T, math.ceil(T * valid_ratio))
/code/mmocr/mmocr/models/textrecog/decoders/nrtr_decoder.py:137: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  valid_width = min(T, math.ceil(T * valid_ratio))
/root/workspace/mmdeploy/mmdeploy/codebase/mmocr/models/text_recognition/transformer_module.py:23: TracerWarning: torch.tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.
  1.0 / torch.tensor(10000).to(device).pow(
/root/workspace/mmdeploy/mmdeploy/codebase/mmocr/models/text_recognition/transformer_module.py:24: TracerWarning: torch.tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.
  torch.tensor(2 * (hid_j // 2) / d_hid)).to(device)
/root/workspace/mmdeploy/mmdeploy/codebase/mmocr/models/text_recognition/transformer_module.py:22: TracerWarning: torch.Tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.
  denominator = torch.Tensor([
/root/workspace/mmdeploy/mmdeploy/codebase/mmocr/models/text_recognition/transformer_module.py:22: TracerWarning: Converting a tensor to a Python float might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  denominator = torch.Tensor([
Warning: Constant folding - Only steps=1 can be constant folded for opset >= 10 onnx::Slice op. Constant folding not applied.
....
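The TracerWarnings at satrn_encoder.py:85 and nrtr_decoder.py:137 look like the likely suspects: valid_width = min(w, math.ceil(w * valid_ratio)) is evaluated in Python, so tracing records its result as a constant. If the padding mask is filled per sample in a Python loop over the batch (an assumption based on the warning's location, not confirmed in this thread), the exported graph would only contain the iterations seen with the single export image, which could explain why the later images in a batch are not handled correctly. The NumPy sketch below only illustrates that failure mode and a batch-agnostic alternative built from array operations; `mask_loop`, `mask_vectorized` and `traced_batch` are hypothetical names, and this is not the code of the actual fix.

```python
import math

import numpy as np


def mask_loop(valid_ratios, h, w, traced_batch=None):
    """Per-sample Python loop (hypothetical stand-in for the encoder code).

    `traced_batch` mimics tracing: only the first `traced_batch` iterations
    are "recorded", as if the loop had been unrolled at export time.
    """
    n = len(valid_ratios)
    mask = np.zeros((n, h, w), dtype=np.float32)
    steps = n if traced_batch is None else min(n, traced_batch)
    for i in range(steps):
        # Python arithmetic: under tracing this value becomes a constant.
        valid_width = min(w, math.ceil(w * valid_ratios[i]))
        mask[i, :, :valid_width] = 1
    return mask


def mask_vectorized(valid_ratios, h, w):
    """Same mask built from array ops only, so it works for any batch size."""
    ratios = np.asarray(valid_ratios, dtype=np.float32)
    valid_width = np.minimum(w, np.ceil(w * ratios))  # shape (n,)
    cols = np.arange(w)  # shape (w,)
    mask_2d = cols[None, :] < valid_width[:, None]  # shape (n, w)
    mask = np.broadcast_to(mask_2d[:, None, :], (len(ratios), h, w))
    return mask.astype(np.float32)


if __name__ == "__main__":
    ratios = [1.0, 1.0, 1.0]  # three identical images, as in the report
    frozen = mask_loop(ratios, h=2, w=4, traced_batch=1)
    print(frozen.sum(axis=(1, 2)))  # [8. 0. 0.]: images 2 and 3 fully masked
    print(np.array_equal(mask_vectorized(ratios, 2, 4), mask_loop(ratios, 2, 4)))  # True
```

In a real PyTorch model the vectorized form would use the equivalent tensor operations, so the per-sample widths stay in the graph instead of being frozen.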

Reproduction

python tools/deploy.py configs/mmocr/text-recognition/text-recognition_onnxruntime_dynamic.py /code/mmocr/run/config.py epoch_20.pth 0.jpg --work-dir work_dir_ocr --dump-info

config.py:

dictionary = dict(
    type='Dictionary',
    dict_file='/code/mmocr/run/digits.txt',  # 0,1,2,3,4,5,6,7,8,9
    with_padding=True,
    with_unknown=True,
    same_start_end=True,
    with_start=True,
    with_end=True)
model = dict(
    type='SATRN',
    backbone=dict(type='ShallowCNN', input_channels=3, hidden_dim=256),
    encoder=dict(
        type='SATRNEncoder',
        n_layers=6,
        n_head=8,
        d_k=32,
        d_v=32,
        d_model=256,
        n_position=100,
        d_inner=1024,
        dropout=0.1),
    decoder=dict(
        type='NRTRDecoder',
        n_layers=6,
        d_embedding=256,
        n_head=8,
        d_model=256,
        d_inner=1024,
        d_k=32,
        d_v=32,
        module_loss=dict(
            type='CEModuleLoss', flatten=True, ignore_first_char=True),
        dictionary=dict(
            type='Dictionary',
            dict_file='/code/mmocr/run/digits.txt',
            with_padding=True,
            with_unknown=True,
            same_start_end=True,
            with_start=True,
            with_end=True),
        max_seq_len=25,
        postprocessor=dict(type='AttentionPostprocessor')),
    data_preprocessor=dict(
        type='TextRecogDataPreprocessor',
        mean=[123.675, 116.28, 103.53],
        std=[58.395, 57.12, 57.375]))
test_pipeline = [
    dict(type='LoadImageFromFile', file_client_args=dict(backend='disk')),
    dict(type='Resize', scale=(100, 32), keep_ratio=False),
    dict(type='LoadOCRAnnotations', with_text=True),
    dict(
        type='PackTextRecogInputs',
        meta_keys=('img_path', 'ori_shape', 'img_shape', 'valid_ratio'))
]
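To read the argmax indices printed earlier, they have to be mapped back through the dictionary. Assuming MMOCR's Dictionary appends its special tokens after the ten characters of digits.txt in the order start/end, padding, unknown (so with same_start_end=True index 10 would be <BOS/EOS>, 11 <PAD> and 12 <UKN>), a greedy decode that skips padding and stops at the end token could look like the sketch below. `decode_indices` and the token layout are assumptions for illustration; the real post-processing is done by AttentionPostprocessor.

```python
# Assumed index layout for the digits dictionary in config.py:
# 0-9 -> '0'-'9', then the special tokens appended in this order.
IDX2TOKEN = [str(d) for d in range(10)] + ['<BOS/EOS>', '<PAD>', '<UKN>']
END_IDX = 10  # '<BOS/EOS>' (same_start_end=True), assumed
PADDING_IDX = 11  # '<PAD>', assumed


def decode_indices(indices, ignore=(PADDING_IDX,)):
    """Greedy decode of one row of argmax indices (hypothetical helper).

    Skips ignored indices and stops at the first end token.
    """
    chars = []
    for idx in indices:
        idx = int(idx)
        if idx in ignore:
            continue
        if idx == END_IDX:
            break
        chars.append(IDX2TOKEN[idx])
    return ''.join(chars)


if __name__ == '__main__':
    # First two rows of the output reported above.
    row_0 = [0, 0, 0, 0, 0, 3, 10, 10, 0, 0, 0, 0, 0,
             0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
    row_1 = [0, 1, 3, 10, 0, 1, 10, 10, 0, 1, 0, 10, 0,
             0, 0, 0, 0, 0, 0, 0, 3, 3, 10, 0, 0]
    print(decode_indices(row_0))  # 000003
    print(decode_indices(row_1))  # 013
```

Under these assumptions the first row decodes to "000003" and the second and third rows to "013", so the same image yields two different strings within one batch.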

Environment

05/24 08:12:29 - mmengine - INFO -

05/24 08:12:29 - mmengine - INFO - **********Environmental information**********
05/24 08:12:31 - mmengine - INFO - sys.platform: linux
05/24 08:12:31 - mmengine - INFO - Python: 3.8.16 (default, Mar  2 2023, 03:21:46) [GCC 11.2.0]
05/24 08:12:31 - mmengine - INFO - CUDA available: True
05/24 08:12:31 - mmengine - INFO - numpy_random_seed: 2147483648
05/24 08:12:31 - mmengine - INFO - GPU 0,1: Tesla V100-PCIE-32GB
05/24 08:12:31 - mmengine - INFO - CUDA_HOME: /usr/local/cuda
05/24 08:12:31 - mmengine - INFO - NVCC: Cuda compilation tools, release 11.6, V11.6.124
05/24 08:12:31 - mmengine - INFO - GCC: gcc (Ubuntu 9.4.0-1ubuntu1~20.04.1) 9.4.0
05/24 08:12:31 - mmengine - INFO - PyTorch: 1.12.0
05/24 08:12:31 - mmengine - INFO - PyTorch compiling details: PyTorch built with:
  - GCC 9.3
  - C++ Version: 201402
  - Intel(R) Math Kernel Library Version 2020.0.2 Product Build 20200624 for Intel(R) 64 architecture applications
  - Intel(R) MKL-DNN v2.6.0 (Git Hash 52b5f107dd9cf10910aaa19cb47f3abf9b349815)
  - OpenMP 201511 (a.k.a. OpenMP 4.5)
  - LAPACK is enabled (usually provided by MKL)
  - NNPACK is enabled
  - CPU capability usage: AVX2
  - CUDA Runtime 11.3
  - NVCC architecture flags: -gencode;arch=compute_37,code=sm_37;-gencode;arch=compute_50,code=sm_50;-gencode;arch=compute_60,code=sm_60;-gencode;arch=compute_61,code=sm_61;-gencode;arch=compute_70,code=sm_70;-gencode;arch=compute_75,code=sm_75;-gencode;arch=compute_80,code=sm_80;-gencode;arch=compute_86,code=sm_86;-gencode;arch=compute_37,code=compute_37
  - CuDNN 8.3.2  (built against CUDA 11.5)
  - Magma 2.5.2
  - Build settings: BLAS_INFO=mkl, BUILD_TYPE=Release, CUDA_VERSION=11.3, CUDNN_VERSION=8.3.2, CXX_COMPILER=/opt/rh/devtoolset-9/root/usr/bin/c++, CXX_FLAGS= -Wno-deprecated -fvisibility-inlines-hidden -DUSE_PTHREADPOOL -fopenmp -DNDEBUG -DUSE_KINETO -DUSE_FBGEMM -DUSE_QNNPACK -DUSE_PYTORCH_QNNPACK -DUSE_XNNPACK -DSYMBOLICATE_MOBILE_DEBUG_HANDLE -DEDGE_PROFILER_USE_KINETO -O2 -fPIC -Wno-narrowing -Wall -Wextra -Werror=return-type -Wno-missing-field-initializers -Wno-type-limits -Wno-array-bounds -Wno-unknown-pragmas -Wno-unused-parameter -Wno-unused-function -Wno-unused-result -Wno-unused-local-typedefs -Wno-strict-overflow -Wno-strict-aliasing -Wno-error=deprecated-declarations -Wno-stringop-overflow -Wno-psabi -Wno-error=pedantic -Wno-error=redundant-decls -Wno-error=old-style-cast -fdiagnostics-color=always -faligned-new -Wno-unused-but-set-variable -Wno-maybe-uninitialized -fno-math-errno -fno-trapping-math -Werror=format -Werror=cast-function-type -Wno-stringop-overflow, LAPACK_INFO=mkl, PERF_WITH_AVX=1, PERF_WITH_AVX2=1, PERF_WITH_AVX512=1, TORCH_VERSION=1.12.0, USE_CUDA=ON, USE_CUDNN=ON, USE_EXCEPTION_PTR=1, USE_GFLAGS=OFF, USE_GLOG=OFF, USE_MKL=ON, USE_MKLDNN=OFF, USE_MPI=OFF, USE_NCCL=ON, USE_NNPACK=ON, USE_OPENMP=ON, USE_ROCM=OFF,

05/24 08:12:31 - mmengine - INFO - TorchVision: 0.13.0
05/24 08:12:31 - mmengine - INFO - OpenCV: 4.7.0
05/24 08:12:31 - mmengine - INFO - MMEngine: 0.6.0
05/24 08:12:31 - mmengine - INFO - MMCV: 2.0.0rc4
05/24 08:12:31 - mmengine - INFO - MMCV Compiler: GCC 9.3
05/24 08:12:31 - mmengine - INFO - MMCV CUDA Compiler: 11.3
05/24 08:12:31 - mmengine - INFO - MMDeploy: 1.0.0rc3+7f30e42
05/24 08:12:31 - mmengine - INFO -

05/24 08:12:31 - mmengine - INFO - **********Backend information**********
05/24 08:12:31 - mmengine - INFO - tensorrt:    8.2.4.2
05/24 08:12:31 - mmengine - INFO - tensorrt custom ops: Available
05/24 08:12:31 - mmengine - INFO - ONNXRuntime: None
05/24 08:12:31 - mmengine - INFO - ONNXRuntime-gpu:     1.12.1
05/24 08:12:31 - mmengine - INFO - ONNXRuntime custom ops:      Available
05/24 08:12:31 - mmengine - INFO - pplnn:       None
05/24 08:12:31 - mmengine - INFO - ncnn:        None
05/24 08:12:31 - mmengine - INFO - snpe:        None
05/24 08:12:31 - mmengine - INFO - openvino:    None
05/24 08:12:31 - mmengine - INFO - torchscript: 1.12.0
05/24 08:12:31 - mmengine - INFO - torchscript custom ops:      NotAvailable
05/24 08:12:31 - mmengine - INFO - rknn-toolkit:        None
05/24 08:12:31 - mmengine - INFO - rknn-toolkit2:       None
05/24 08:12:31 - mmengine - INFO - ascend:      None
05/24 08:12:31 - mmengine - INFO - coreml:      None
05/24 08:12:31 - mmengine - INFO - tvm: None
05/24 08:12:31 - mmengine - INFO - vacc:        None
05/24 08:12:31 - mmengine - INFO -

05/24 08:12:31 - mmengine - INFO - **********Codebase information**********
05/24 08:12:31 - mmengine - INFO - mmdet:       3.0.0rc6
05/24 08:12:31 - mmengine - INFO - mmseg:       1.0.0rc6
05/24 08:12:31 - mmengine - INFO - mmcls:       1.0.0rc5
05/24 08:12:31 - mmengine - INFO - mmocr:       1.0.0rc6
05/24 08:12:31 - mmengine - INFO - mmedit:      None
05/24 08:12:31 - mmengine - INFO - mmdet3d:     None
05/24 08:12:31 - mmengine - INFO - mmpose:      1.0.0
05/24 08:12:31 - mmengine - INFO - mmrotate:    None
05/24 08:12:31 - mmengine - INFO - mmaction:    None

Error traceback

No response

@CescMessi
Contributor Author

Found a discussion about this in #678. Is there any plan to fix it now?

@irexyc
Collaborator

irexyc commented May 25, 2023

I can reproduce it on my machine. I will check it soon.

@irexyc irexyc self-assigned this May 25, 2023
@irexyc
Collaborator

irexyc commented May 31, 2023

Hi @CescMessi

You can check this fix #2139

@CescMessi
Contributor Author

Thank you! I will try it soon.

@CescMessi
Contributor Author

It works well, thank you!
