Rtm_det-ins incorrect mask dimensions after ONNX export #2571

Closed
3 tasks done
LuukvandenBent opened this issue Nov 28, 2023 · 4 comments

LuukvandenBent commented Nov 28, 2023

Checklist

  • I have searched related issues but cannot get the expected help.
  • I have read the FAQ documentation but cannot get the expected help.
  • The bug has not been fixed in the latest version.

Describe the bug

RTMDet-Ins produces masks with incorrect dimensions after ONNX export: the mask output has one fewer instance than the detection output.
Related to: #2334
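One quick way to confirm this symptom on the exported model's outputs is to compare the instance dimension of each output. The sketch below assumes the end-to-end outputs are ordered (dets, labels, masks) with shapes (batch, N, 5), (batch, N) and (batch, N, H, W). That layout is inferred from the traceback further down, not from mmdeploy documentation, and `check_instance_counts` is a hypothetical helper.

```python
import numpy as np


def check_instance_counts(dets, labels, masks):
    """Check that the end-to-end outputs agree on the number of instances.

    Assumed layout (inferred from the traceback, not from mmdeploy docs):
      dets:   (batch, num_dets, 5)  box coordinates + score
      labels: (batch, num_dets)
      masks:  (batch, num_dets, height, width)
    Returns a list of problems; an empty list means the counts agree.
    """
    problems = []
    num_dets = dets.shape[1]
    if labels.shape[1] != num_dets:
        problems.append(
            f"labels has {labels.shape[1]} instances, dets has {num_dets}")
    if masks.shape[1] != num_dets:
        problems.append(
            f"masks has {masks.shape[1]} instances, dets has {num_dets}")
    return problems


# Instance counts from the traceback, with a tiny spatial size to save memory.
dets = np.zeros((1, 100, 5), dtype=np.float32)
labels = np.zeros((1, 100), dtype=np.int64)
masks = np.zeros((1, 99, 8, 8), dtype=bool)
print(check_instance_counts(dets, labels, masks))
# -> ['masks has 99 instances, dets has 100']
```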

Reproduction

python3 tools/deploy.py configs/mmdet/instance-seg/instance-seg_rtmdet-ins_onnxruntime_static-640x640.py mmdetection/configs/rtmdet/rtmdet-ins_s_8xb32-300e_coco.py https://download.openmmlab.com/mmdetection/v3.0/rtmdet/cspnext_rsb_pretrain/cspnext-s_imagenet_600e.pth mmdetection/demo/demo.jpg --show --work-dir work_dir
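For reference, the same reproduction using the checkpoint the maintainer recommends later in this thread can be assembled as below. This is only a sketch: it reuses the paths from the command above and assumes it is run from the mmdeploy root. Swapping the checkpoint fixes the weight mismatch, not the mask-count bug itself.

```python
import shlex

# Paths copied from the command above; the checkpoint URL is the
# rtmdet-ins_s COCO checkpoint suggested in the maintainer's reply.
DEPLOY_CFG = ("configs/mmdet/instance-seg/"
              "instance-seg_rtmdet-ins_onnxruntime_static-640x640.py")
MODEL_CFG = "mmdetection/configs/rtmdet/rtmdet-ins_s_8xb32-300e_coco.py"
CHECKPOINT = ("https://download.openmmlab.com/mmdetection/v3.0/rtmdet/"
              "rtmdet-ins_s_8xb32-300e_coco/"
              "rtmdet-ins_s_8xb32-300e_coco_20221121_212604-fdc5d7ec.pth")
IMAGE = "mmdetection/demo/demo.jpg"

argv = ["python3", "tools/deploy.py", DEPLOY_CFG, MODEL_CFG, CHECKPOINT,
        IMAGE, "--show", "--work-dir", "work_dir"]

# Print a shell-ready command; to execute it instead, use
# subprocess.run(argv, check=True) from the mmdeploy root.
print(shlex.join(argv))
```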

Environment

Using docker image: [ubuntu20.04-cuda11.8-mmdeploy](https://hub.docker.com/layers/openmmlab/mmdeploy/ubuntu20.04-cuda11.8-mmdeploy/images/sha256-f2b90730c5e1b2b03b775bc592b8375af1daebb1d4ff80f5bff1594073fd0524?context=explore)
Pulled the latest commits and rebuilt.

11/28 12:14:35 - mmengine - INFO - **********Environmental information**********
11/28 12:14:36 - mmengine - INFO - sys.platform: linux
11/28 12:14:36 - mmengine - INFO - Python: 3.8.10 (default, May 26 2023, 14:05:08) [GCC 9.4.0]
11/28 12:14:36 - mmengine - INFO - CUDA available: True
11/28 12:14:36 - mmengine - INFO - numpy_random_seed: 2147483648
11/28 12:14:36 - mmengine - INFO - GPU 0: NVIDIA GeForce RTX 3060 Laptop GPU
11/28 12:14:36 - mmengine - INFO - CUDA_HOME: /usr/local/cuda
11/28 12:14:36 - mmengine - INFO - NVCC: Cuda compilation tools, release 11.8, V11.8.89
11/28 12:14:36 - mmengine - INFO - GCC: x86_64-linux-gnu-gcc (Ubuntu 9.4.0-1ubuntu1~20.04.1) 9.4.0
11/28 12:14:36 - mmengine - INFO - PyTorch: 2.0.0+cu118
11/28 12:14:36 - mmengine - INFO - PyTorch compiling details: PyTorch built with:
  - GCC 9.3
  - C++ Version: 201703
  - Intel(R) oneAPI Math Kernel Library Version 2022.2-Product Build 20220804 for Intel(R) 64 architecture applications
  - Intel(R) MKL-DNN v2.7.3 (Git Hash 6dbeffbae1f23cbbeae17adb7b5b13f1f37c080e)
  - OpenMP 201511 (a.k.a. OpenMP 4.5)
  - LAPACK is enabled (usually provided by MKL)
  - NNPACK is enabled
  - CPU capability usage: AVX2
  - CUDA Runtime 11.8
  - NVCC architecture flags: -gencode;arch=compute_37,code=sm_37;-gencode;arch=compute_50,code=sm_50;-gencode;arch=compute_60,code=sm_60;-gencode;arch=compute_70,code=sm_70;-gencode;arch=compute_75,code=sm_75;-gencode;arch=compute_80,code=sm_80;-gencode;arch=compute_86,code=sm_86;-gencode;arch=compute_90,code=sm_90
  - CuDNN 8.7
  - Magma 2.6.1
  - Build settings: BLAS_INFO=mkl, BUILD_TYPE=Release, CUDA_VERSION=11.8, CUDNN_VERSION=8.7.0, CXX_COMPILER=/opt/rh/devtoolset-9/root/usr/bin/c++, CXX_FLAGS= -D_GLIBCXX_USE_CXX11_ABI=0 -fabi-version=11 -Wno-deprecated -fvisibility-inlines-hidden -DUSE_PTHREADPOOL -DNDEBUG -DUSE_KINETO -DLIBKINETO_NOROCTRACER -DUSE_FBGEMM -DUSE_QNNPACK -DUSE_PYTORCH_QNNPACK -DUSE_XNNPACK -DSYMBOLICATE_MOBILE_DEBUG_HANDLE -O2 -fPIC -Wall -Wextra -Werror=return-type -Werror=non-virtual-dtor -Werror=bool-operation -Wnarrowing -Wno-missing-field-initializers -Wno-type-limits -Wno-array-bounds -Wno-unknown-pragmas -Wunused-local-typedefs -Wno-unused-parameter -Wno-unused-function -Wno-unused-result -Wno-strict-overflow -Wno-strict-aliasing -Wno-error=deprecated-declarations -Wno-stringop-overflow -Wno-psabi -Wno-error=pedantic -Wno-error=redundant-decls -Wno-error=old-style-cast -fdiagnostics-color=always -faligned-new -Wno-unused-but-set-variable -Wno-maybe-uninitialized -fno-math-errno -fno-trapping-math -Werror=format -Werror=cast-function-type -Wno-stringop-overflow, LAPACK_INFO=mkl, PERF_WITH_AVX=1, PERF_WITH_AVX2=1, PERF_WITH_AVX512=1, TORCH_DISABLE_GPU_ASSERTS=ON, TORCH_VERSION=2.0.0, USE_CUDA=ON, USE_CUDNN=ON, USE_EXCEPTION_PTR=1, USE_GFLAGS=OFF, USE_GLOG=OFF, USE_MKL=ON, USE_MKLDNN=ON, USE_MPI=OFF, USE_NCCL=1, USE_NNPACK=ON, USE_OPENMP=ON, USE_ROCM=OFF, 

11/28 12:14:36 - mmengine - INFO - TorchVision: 0.15.0+cu118
11/28 12:14:36 - mmengine - INFO - OpenCV: 4.5.4
11/28 12:14:36 - mmengine - INFO - MMEngine: 0.8.5
11/28 12:14:36 - mmengine - INFO - MMCV: 2.0.1
11/28 12:14:36 - mmengine - INFO - MMCV Compiler: GCC 9.3
11/28 12:14:36 - mmengine - INFO - MMCV CUDA Compiler: 11.8
11/28 12:14:36 - mmengine - INFO - MMDeploy: 1.3.0+8b19586
11/28 12:14:36 - mmengine - INFO - 

11/28 12:14:36 - mmengine - INFO - **********Backend information**********
11/28 12:14:36 - mmengine - INFO - tensorrt:    8.6.1
11/28 12:14:36 - mmengine - INFO - tensorrt custom ops: Available
11/28 12:14:36 - mmengine - INFO - ONNXRuntime: None
11/28 12:14:36 - mmengine - INFO - ONNXRuntime-gpu:     1.15.1
11/28 12:14:36 - mmengine - INFO - ONNXRuntime custom ops:      Available
11/28 12:14:36 - mmengine - INFO - pplnn:       0.8.1
11/28 12:14:36 - mmengine - INFO - ncnn:        1.0.20230905
11/28 12:14:36 - mmengine - INFO - ncnn custom ops:     Available
11/28 12:14:36 - mmengine - INFO - snpe:        None
11/28 12:14:36 - mmengine - INFO - openvino:    2023.0.2
11/28 12:14:36 - mmengine - INFO - torchscript: 2.0.0+cu118
11/28 12:14:36 - mmengine - INFO - torchscript custom ops:      Available
11/28 12:14:36 - mmengine - INFO - rknn-toolkit:        None
11/28 12:14:36 - mmengine - INFO - rknn-toolkit2:       None
11/28 12:14:36 - mmengine - INFO - ascend:      None
11/28 12:14:36 - mmengine - INFO - coreml:      None
11/28 12:14:36 - mmengine - INFO - tvm: None
11/28 12:14:36 - mmengine - INFO - vacc:        None
11/28 12:14:36 - mmengine - INFO - 

11/28 12:14:36 - mmengine - INFO - **********Codebase information**********
11/28 12:14:36 - mmengine - INFO - mmdet:       3.2.0
11/28 12:14:36 - mmengine - INFO - mmseg:       None
11/28 12:14:36 - mmengine - INFO - mmpretrain:  None
11/28 12:14:36 - mmengine - INFO - mmocr:       None
11/28 12:14:36 - mmengine - INFO - mmagic:      None
11/28 12:14:36 - mmengine - INFO - mmdet3d:     None
11/28 12:14:36 - mmengine - INFO - mmpose:      None
11/28 12:14:36 - mmengine - INFO - mmrotate:    None
11/28 12:14:36 - mmengine - INFO - mmaction:    None
11/28 12:14:36 - mmengine - INFO - mmrazor:     None
11/28 12:14:36 - mmengine - INFO - mmyolo:      None

Error traceback

11/28 12:10:22 - mmengine - INFO - Start pipeline mmdeploy.apis.pytorch2onnx.torch2onnx in subprocess
11/28 12:10:23 - mmengine - WARNING - Failed to search registry with scope "mmdet" in the "Codebases" registry tree. As a workaround, the current "Codebases" registry in "mmdeploy" is used to build instance. This may cause unexpected failure when running the built modules. Please check whether "mmdet" is a correct scope, or whether the registry is initialized.
11/28 12:10:23 - mmengine - WARNING - Failed to search registry with scope "mmdet" in the "mmdet_tasks" registry tree. As a workaround, the current "mmdet_tasks" registry in "mmdeploy" is used to build instance. This may cause unexpected failure when running the built modules. Please check whether "mmdet" is a correct scope, or whether the registry is initialized.
Loads checkpoint by http backend from path: https://download.openmmlab.com/mmdetection/v3.0/rtmdet/cspnext_rsb_pretrain/cspnext-s_imagenet_600e.pth
The model and loaded state dict do not match exactly

unexpected key in source state_dict: head.fc.weight, head.fc.bias

missing keys in source state_dict: neck.reduce_layers.0.conv.weight, neck.reduce_layers.0.bn.weight, neck.reduce_layers.0.bn.bias, neck.reduce_layers.0.bn.running_mean, neck.reduce_layers.0.bn.running_var, neck.reduce_layers.1.conv.weight, neck.reduce_layers.1.bn.weight, neck.reduce_layers.1.bn.bias, neck.reduce_layers.1.bn.running_mean, neck.reduce_layers.1.bn.running_var, neck.top_down_blocks.0.main_conv.conv.weight, neck.top_down_blocks.0.main_conv.bn.weight, neck.top_down_blocks.0.main_conv.bn.bias, neck.top_down_blocks.0.main_conv.bn.running_mean, neck.top_down_blocks.0.main_conv.bn.running_var, neck.top_down_blocks.0.short_conv.conv.weight, neck.top_down_blocks.0.short_conv.bn.weight, neck.top_down_blocks.0.short_conv.bn.bias, neck.top_down_blocks.0.short_conv.bn.running_mean, neck.top_down_blocks.0.short_conv.bn.running_var, neck.top_down_blocks.0.final_conv.conv.weight, neck.top_down_blocks.0.final_conv.bn.weight, neck.top_down_blocks.0.final_conv.bn.bias, neck.top_down_blocks.0.final_conv.bn.running_mean, neck.top_down_blocks.0.final_conv.bn.running_var, neck.top_down_blocks.0.blocks.0.conv1.conv.weight, neck.top_down_blocks.0.blocks.0.conv1.bn.weight, neck.top_down_blocks.0.blocks.0.conv1.bn.bias, neck.top_down_blocks.0.blocks.0.conv1.bn.running_mean, neck.top_down_blocks.0.blocks.0.conv1.bn.running_var, neck.top_down_blocks.0.blocks.0.conv2.depthwise_conv.conv.weight, neck.top_down_blocks.0.blocks.0.conv2.depthwise_conv.bn.weight, neck.top_down_blocks.0.blocks.0.conv2.depthwise_conv.bn.bias, neck.top_down_blocks.0.blocks.0.conv2.depthwise_conv.bn.running_mean, neck.top_down_blocks.0.blocks.0.conv2.depthwise_conv.bn.running_var, neck.top_down_blocks.0.blocks.0.conv2.pointwise_conv.conv.weight, neck.top_down_blocks.0.blocks.0.conv2.pointwise_conv.bn.weight, neck.top_down_blocks.0.blocks.0.conv2.pointwise_conv.bn.bias, neck.top_down_blocks.0.blocks.0.conv2.pointwise_conv.bn.running_mean, neck.top_down_blocks.0.blocks.0.conv2.pointwise_conv.bn.running_var, 
neck.top_down_blocks.1.main_conv.conv.weight, neck.top_down_blocks.1.main_conv.bn.weight, neck.top_down_blocks.1.main_conv.bn.bias, neck.top_down_blocks.1.main_conv.bn.running_mean, neck.top_down_blocks.1.main_conv.bn.running_var, neck.top_down_blocks.1.short_conv.conv.weight, neck.top_down_blocks.1.short_conv.bn.weight, neck.top_down_blocks.1.short_conv.bn.bias, neck.top_down_blocks.1.short_conv.bn.running_mean, neck.top_down_blocks.1.short_conv.bn.running_var, neck.top_down_blocks.1.final_conv.conv.weight, neck.top_down_blocks.1.final_conv.bn.weight, neck.top_down_blocks.1.final_conv.bn.bias, neck.top_down_blocks.1.final_conv.bn.running_mean, neck.top_down_blocks.1.final_conv.bn.running_var, neck.top_down_blocks.1.blocks.0.conv1.conv.weight, neck.top_down_blocks.1.blocks.0.conv1.bn.weight, neck.top_down_blocks.1.blocks.0.conv1.bn.bias, neck.top_down_blocks.1.blocks.0.conv1.bn.running_mean, neck.top_down_blocks.1.blocks.0.conv1.bn.running_var, neck.top_down_blocks.1.blocks.0.conv2.depthwise_conv.conv.weight, neck.top_down_blocks.1.blocks.0.conv2.depthwise_conv.bn.weight, neck.top_down_blocks.1.blocks.0.conv2.depthwise_conv.bn.bias, neck.top_down_blocks.1.blocks.0.conv2.depthwise_conv.bn.running_mean, neck.top_down_blocks.1.blocks.0.conv2.depthwise_conv.bn.running_var, neck.top_down_blocks.1.blocks.0.conv2.pointwise_conv.conv.weight, neck.top_down_blocks.1.blocks.0.conv2.pointwise_conv.bn.weight, neck.top_down_blocks.1.blocks.0.conv2.pointwise_conv.bn.bias, neck.top_down_blocks.1.blocks.0.conv2.pointwise_conv.bn.running_mean, neck.top_down_blocks.1.blocks.0.conv2.pointwise_conv.bn.running_var, neck.downsamples.0.conv.weight, neck.downsamples.0.bn.weight, neck.downsamples.0.bn.bias, neck.downsamples.0.bn.running_mean, neck.downsamples.0.bn.running_var, neck.downsamples.1.conv.weight, neck.downsamples.1.bn.weight, neck.downsamples.1.bn.bias, neck.downsamples.1.bn.running_mean, neck.downsamples.1.bn.running_var, neck.bottom_up_blocks.0.main_conv.conv.weight, 
neck.bottom_up_blocks.0.main_conv.bn.weight, neck.bottom_up_blocks.0.main_conv.bn.bias, neck.bottom_up_blocks.0.main_conv.bn.running_mean, neck.bottom_up_blocks.0.main_conv.bn.running_var, neck.bottom_up_blocks.0.short_conv.conv.weight, neck.bottom_up_blocks.0.short_conv.bn.weight, neck.bottom_up_blocks.0.short_conv.bn.bias, neck.bottom_up_blocks.0.short_conv.bn.running_mean, neck.bottom_up_blocks.0.short_conv.bn.running_var, neck.bottom_up_blocks.0.final_conv.conv.weight, neck.bottom_up_blocks.0.final_conv.bn.weight, neck.bottom_up_blocks.0.final_conv.bn.bias, neck.bottom_up_blocks.0.final_conv.bn.running_mean, neck.bottom_up_blocks.0.final_conv.bn.running_var, neck.bottom_up_blocks.0.blocks.0.conv1.conv.weight, neck.bottom_up_blocks.0.blocks.0.conv1.bn.weight, neck.bottom_up_blocks.0.blocks.0.conv1.bn.bias, neck.bottom_up_blocks.0.blocks.0.conv1.bn.running_mean, neck.bottom_up_blocks.0.blocks.0.conv1.bn.running_var, neck.bottom_up_blocks.0.blocks.0.conv2.depthwise_conv.conv.weight, neck.bottom_up_blocks.0.blocks.0.conv2.depthwise_conv.bn.weight, neck.bottom_up_blocks.0.blocks.0.conv2.depthwise_conv.bn.bias, neck.bottom_up_blocks.0.blocks.0.conv2.depthwise_conv.bn.running_mean, neck.bottom_up_blocks.0.blocks.0.conv2.depthwise_conv.bn.running_var, neck.bottom_up_blocks.0.blocks.0.conv2.pointwise_conv.conv.weight, neck.bottom_up_blocks.0.blocks.0.conv2.pointwise_conv.bn.weight, neck.bottom_up_blocks.0.blocks.0.conv2.pointwise_conv.bn.bias, neck.bottom_up_blocks.0.blocks.0.conv2.pointwise_conv.bn.running_mean, neck.bottom_up_blocks.0.blocks.0.conv2.pointwise_conv.bn.running_var, neck.bottom_up_blocks.1.main_conv.conv.weight, neck.bottom_up_blocks.1.main_conv.bn.weight, neck.bottom_up_blocks.1.main_conv.bn.bias, neck.bottom_up_blocks.1.main_conv.bn.running_mean, neck.bottom_up_blocks.1.main_conv.bn.running_var, neck.bottom_up_blocks.1.short_conv.conv.weight, neck.bottom_up_blocks.1.short_conv.bn.weight, neck.bottom_up_blocks.1.short_conv.bn.bias, 
neck.bottom_up_blocks.1.short_conv.bn.running_mean, neck.bottom_up_blocks.1.short_conv.bn.running_var, neck.bottom_up_blocks.1.final_conv.conv.weight, neck.bottom_up_blocks.1.final_conv.bn.weight, neck.bottom_up_blocks.1.final_conv.bn.bias, neck.bottom_up_blocks.1.final_conv.bn.running_mean, neck.bottom_up_blocks.1.final_conv.bn.running_var, neck.bottom_up_blocks.1.blocks.0.conv1.conv.weight, neck.bottom_up_blocks.1.blocks.0.conv1.bn.weight, neck.bottom_up_blocks.1.blocks.0.conv1.bn.bias, neck.bottom_up_blocks.1.blocks.0.conv1.bn.running_mean, neck.bottom_up_blocks.1.blocks.0.conv1.bn.running_var, neck.bottom_up_blocks.1.blocks.0.conv2.depthwise_conv.conv.weight, neck.bottom_up_blocks.1.blocks.0.conv2.depthwise_conv.bn.weight, neck.bottom_up_blocks.1.blocks.0.conv2.depthwise_conv.bn.bias, neck.bottom_up_blocks.1.blocks.0.conv2.depthwise_conv.bn.running_mean, neck.bottom_up_blocks.1.blocks.0.conv2.depthwise_conv.bn.running_var, neck.bottom_up_blocks.1.blocks.0.conv2.pointwise_conv.conv.weight, neck.bottom_up_blocks.1.blocks.0.conv2.pointwise_conv.bn.weight, neck.bottom_up_blocks.1.blocks.0.conv2.pointwise_conv.bn.bias, neck.bottom_up_blocks.1.blocks.0.conv2.pointwise_conv.bn.running_mean, neck.bottom_up_blocks.1.blocks.0.conv2.pointwise_conv.bn.running_var, neck.out_convs.0.conv.weight, neck.out_convs.0.bn.weight, neck.out_convs.0.bn.bias, neck.out_convs.0.bn.running_mean, neck.out_convs.0.bn.running_var, neck.out_convs.1.conv.weight, neck.out_convs.1.bn.weight, neck.out_convs.1.bn.bias, neck.out_convs.1.bn.running_mean, neck.out_convs.1.bn.running_var, neck.out_convs.2.conv.weight, neck.out_convs.2.bn.weight, neck.out_convs.2.bn.bias, neck.out_convs.2.bn.running_mean, neck.out_convs.2.bn.running_var, bbox_head.cls_convs.0.0.conv.weight, bbox_head.cls_convs.0.0.bn.weight, bbox_head.cls_convs.0.0.bn.bias, bbox_head.cls_convs.0.0.bn.running_mean, bbox_head.cls_convs.0.0.bn.running_var, bbox_head.cls_convs.0.1.conv.weight, bbox_head.cls_convs.0.1.bn.weight, 
bbox_head.cls_convs.0.1.bn.bias, bbox_head.cls_convs.0.1.bn.running_mean, bbox_head.cls_convs.0.1.bn.running_var, bbox_head.cls_convs.1.0.conv.weight, bbox_head.cls_convs.1.0.bn.weight, bbox_head.cls_convs.1.0.bn.bias, bbox_head.cls_convs.1.0.bn.running_mean, bbox_head.cls_convs.1.0.bn.running_var, bbox_head.cls_convs.1.1.conv.weight, bbox_head.cls_convs.1.1.bn.weight, bbox_head.cls_convs.1.1.bn.bias, bbox_head.cls_convs.1.1.bn.running_mean, bbox_head.cls_convs.1.1.bn.running_var, bbox_head.cls_convs.2.0.conv.weight, bbox_head.cls_convs.2.0.bn.weight, bbox_head.cls_convs.2.0.bn.bias, bbox_head.cls_convs.2.0.bn.running_mean, bbox_head.cls_convs.2.0.bn.running_var, bbox_head.cls_convs.2.1.conv.weight, bbox_head.cls_convs.2.1.bn.weight, bbox_head.cls_convs.2.1.bn.bias, bbox_head.cls_convs.2.1.bn.running_mean, bbox_head.cls_convs.2.1.bn.running_var, bbox_head.reg_convs.0.0.conv.weight, bbox_head.reg_convs.0.0.bn.weight, bbox_head.reg_convs.0.0.bn.bias, bbox_head.reg_convs.0.0.bn.running_mean, bbox_head.reg_convs.0.0.bn.running_var, bbox_head.reg_convs.0.1.conv.weight, bbox_head.reg_convs.0.1.bn.weight, bbox_head.reg_convs.0.1.bn.bias, bbox_head.reg_convs.0.1.bn.running_mean, bbox_head.reg_convs.0.1.bn.running_var, bbox_head.reg_convs.1.0.conv.weight, bbox_head.reg_convs.1.0.bn.weight, bbox_head.reg_convs.1.0.bn.bias, bbox_head.reg_convs.1.0.bn.running_mean, bbox_head.reg_convs.1.0.bn.running_var, bbox_head.reg_convs.1.1.conv.weight, bbox_head.reg_convs.1.1.bn.weight, bbox_head.reg_convs.1.1.bn.bias, bbox_head.reg_convs.1.1.bn.running_mean, bbox_head.reg_convs.1.1.bn.running_var, bbox_head.reg_convs.2.0.conv.weight, bbox_head.reg_convs.2.0.bn.weight, bbox_head.reg_convs.2.0.bn.bias, bbox_head.reg_convs.2.0.bn.running_mean, bbox_head.reg_convs.2.0.bn.running_var, bbox_head.reg_convs.2.1.conv.weight, bbox_head.reg_convs.2.1.bn.weight, bbox_head.reg_convs.2.1.bn.bias, bbox_head.reg_convs.2.1.bn.running_mean, bbox_head.reg_convs.2.1.bn.running_var, 
bbox_head.kernel_convs.0.0.conv.weight, bbox_head.kernel_convs.0.0.bn.weight, bbox_head.kernel_convs.0.0.bn.bias, bbox_head.kernel_convs.0.0.bn.running_mean, bbox_head.kernel_convs.0.0.bn.running_var, bbox_head.kernel_convs.0.1.conv.weight, bbox_head.kernel_convs.0.1.bn.weight, bbox_head.kernel_convs.0.1.bn.bias, bbox_head.kernel_convs.0.1.bn.running_mean, bbox_head.kernel_convs.0.1.bn.running_var, bbox_head.kernel_convs.1.0.conv.weight, bbox_head.kernel_convs.1.0.bn.weight, bbox_head.kernel_convs.1.0.bn.bias, bbox_head.kernel_convs.1.0.bn.running_mean, bbox_head.kernel_convs.1.0.bn.running_var, bbox_head.kernel_convs.1.1.conv.weight, bbox_head.kernel_convs.1.1.bn.weight, bbox_head.kernel_convs.1.1.bn.bias, bbox_head.kernel_convs.1.1.bn.running_mean, bbox_head.kernel_convs.1.1.bn.running_var, bbox_head.kernel_convs.2.0.conv.weight, bbox_head.kernel_convs.2.0.bn.weight, bbox_head.kernel_convs.2.0.bn.bias, bbox_head.kernel_convs.2.0.bn.running_mean, bbox_head.kernel_convs.2.0.bn.running_var, bbox_head.kernel_convs.2.1.conv.weight, bbox_head.kernel_convs.2.1.bn.weight, bbox_head.kernel_convs.2.1.bn.bias, bbox_head.kernel_convs.2.1.bn.running_mean, bbox_head.kernel_convs.2.1.bn.running_var, bbox_head.rtm_cls.0.weight, bbox_head.rtm_cls.0.bias, bbox_head.rtm_cls.1.weight, bbox_head.rtm_cls.1.bias, bbox_head.rtm_cls.2.weight, bbox_head.rtm_cls.2.bias, bbox_head.rtm_reg.0.weight, bbox_head.rtm_reg.0.bias, bbox_head.rtm_reg.1.weight, bbox_head.rtm_reg.1.bias, bbox_head.rtm_reg.2.weight, bbox_head.rtm_reg.2.bias, bbox_head.rtm_kernel.0.weight, bbox_head.rtm_kernel.0.bias, bbox_head.rtm_kernel.1.weight, bbox_head.rtm_kernel.1.bias, bbox_head.rtm_kernel.2.weight, bbox_head.rtm_kernel.2.bias, bbox_head.mask_head.fusion_conv.weight, bbox_head.mask_head.fusion_conv.bias, bbox_head.mask_head.stacked_convs.0.conv.weight, bbox_head.mask_head.stacked_convs.0.bn.weight, bbox_head.mask_head.stacked_convs.0.bn.bias, bbox_head.mask_head.stacked_convs.0.bn.running_mean, 
bbox_head.mask_head.stacked_convs.0.bn.running_var, bbox_head.mask_head.stacked_convs.1.conv.weight, bbox_head.mask_head.stacked_convs.1.bn.weight, bbox_head.mask_head.stacked_convs.1.bn.bias, bbox_head.mask_head.stacked_convs.1.bn.running_mean, bbox_head.mask_head.stacked_convs.1.bn.running_var, bbox_head.mask_head.stacked_convs.2.conv.weight, bbox_head.mask_head.stacked_convs.2.bn.weight, bbox_head.mask_head.stacked_convs.2.bn.bias, bbox_head.mask_head.stacked_convs.2.bn.running_mean, bbox_head.mask_head.stacked_convs.2.bn.running_var, bbox_head.mask_head.stacked_convs.3.conv.weight, bbox_head.mask_head.stacked_convs.3.bn.weight, bbox_head.mask_head.stacked_convs.3.bn.bias, bbox_head.mask_head.stacked_convs.3.bn.running_mean, bbox_head.mask_head.stacked_convs.3.bn.running_var, bbox_head.mask_head.projection.weight, bbox_head.mask_head.projection.bias

11/28 12:10:23 - mmengine - WARNING - DeprecationWarning: get_onnx_config will be deprecated in the future. 
11/28 12:10:23 - mmengine - INFO - Export PyTorch model to ONNX: work_dir/end2end.onnx.
/root/workspace/mmdeploy/mmdeploy/codebase/mmdet/models/detectors/single_stage.py:80: TracerWarning: Iterating over a tensor might cause the trace to be incorrect. Passing a tensor of different shape won't change the number of iterations executed (and might lead to errors or silently give incorrect results).
  img_shape = [int(val) for val in img_shape]
/root/workspace/mmdeploy/mmdeploy/codebase/mmdet/models/detectors/single_stage.py:80: TracerWarning: Converting a tensor to a Python integer might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  img_shape = [int(val) for val in img_shape]
/root/workspace/mmdeploy/mmdeploy/core/optimizers/function_marker.py:160: TracerWarning: Converting a tensor to a Python integer might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  ys_shape = tuple(int(s) for s in ys.shape)
/usr/local/lib/python3.8/dist-packages/torch/functional.py:504: UserWarning: torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument. (Triggered internally at ../aten/src/ATen/native/TensorShape.cpp:3483.)
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
/root/workspace/mmdeploy/mmdeploy/mmcv/ops/nms.py:416: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if not is_dynamic_batch(deploy_cfg) and batch_size == 1:
/root/workspace/mmdeploy/mmdeploy/mmcv/ops/nms.py:327: TracerWarning: torch.tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.
  iou_threshold = torch.tensor([iou_threshold], dtype=torch.float32)
/root/workspace/mmdeploy/mmdeploy/mmcv/ops/nms.py:328: TracerWarning: torch.tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.
  score_threshold = torch.tensor([score_threshold], dtype=torch.float32)
/root/workspace/mmdeploy/mmdeploy/pytorch/functions/topk.py:28: TracerWarning: torch.tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.
  k = torch.tensor(k, device=input.device, dtype=torch.long)
/root/workspace/mmdeploy/mmdeploy/mmcv/ops/nms.py:45: TracerWarning: Converting a tensor to a Python float might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  score_threshold = float(score_threshold)
/root/workspace/mmdeploy/mmdeploy/mmcv/ops/nms.py:46: TracerWarning: Converting a tensor to a Python float might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  iou_threshold = float(iou_threshold)
/usr/local/lib/python3.8/dist-packages/mmcv/ops/nms.py:123: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  assert boxes.size(1) == 4
/usr/local/lib/python3.8/dist-packages/mmcv/ops/nms.py:124: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  assert boxes.size(0) == scores.size(0)
/usr/local/lib/python3.8/dist-packages/torch/onnx/symbolic_opset9.py:5589: UserWarning: Exporting aten::index operator of advanced indexing in opset 11 is achieved by combination of multiple ONNX operators, including Reshape, Transpose, Concat, and Gather. If indices include negative values, the exported graph will produce incorrect results.
  warnings.warn(
11/28 12:10:30 - mmengine - INFO - Execute onnx optimize passes.
============= Diagnostic Run torch.onnx.export version 2.0.0+cu118 =============
verbose: False, log level: Level.ERROR
======================= 0 NONE 0 NOTE 0 WARNING 0 ERROR ========================

11/28 12:10:30 - mmengine - INFO - Finish pipeline mmdeploy.apis.pytorch2onnx.torch2onnx
11/28 12:10:31 - mmengine - INFO - Start pipeline mmdeploy.apis.utils.utils.to_backend in main process
11/28 12:10:31 - mmengine - INFO - Finish pipeline mmdeploy.apis.utils.utils.to_backend
11/28 12:10:31 - mmengine - INFO - visualize onnxruntime model start.
11/28 12:10:33 - mmengine - WARNING - Failed to search registry with scope "mmdet" in the "Codebases" registry tree. As a workaround, the current "Codebases" registry in "mmdeploy" is used to build instance. This may cause unexpected failure when running the built modules. Please check whether "mmdet" is a correct scope, or whether the registry is initialized.
11/28 12:10:33 - mmengine - WARNING - Failed to search registry with scope "mmdet" in the "mmdet_tasks" registry tree. As a workaround, the current "mmdet_tasks" registry in "mmdeploy" is used to build instance. This may cause unexpected failure when running the built modules. Please check whether "mmdet" is a correct scope, or whether the registry is initialized.
11/28 12:10:33 - mmengine - WARNING - Failed to search registry with scope "mmdet" in the "backend_detectors" registry tree. As a workaround, the current "backend_detectors" registry in "mmdeploy" is used to build instance. This may cause unexpected failure when running the built modules. Please check whether "mmdet" is a correct scope, or whether the registry is initialized.
11/28 12:10:33 - mmengine - INFO - Successfully loaded onnxruntime custom ops from /root/workspace/mmdeploy/mmdeploy/lib/libmmdeploy_onnxruntime_ops.so
2023-11-28:12:10:34 - root - ERROR - The shape of the mask [100] at index 0 does not match the shape of the indexed tensor [99, 640, 640] at index 0
Traceback (most recent call last):
  File "/root/workspace/mmdeploy/mmdeploy/utils/utils.py", line 41, in target_wrapper
    result = target(*args, **kwargs)
  File "/root/workspace/mmdeploy/mmdeploy/apis/visualize.py", line 72, in visualize_model
    result = model.test_step(model_inputs)[0]
  File "/usr/local/lib/python3.8/dist-packages/mmengine/model/base_model/base_model.py", line 145, in test_step
    return self._run_forward(data, mode='predict')  # type: ignore
  File "/usr/local/lib/python3.8/dist-packages/mmengine/model/base_model/base_model.py", line 340, in _run_forward
    results = self(**data, mode=mode)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/root/workspace/mmdeploy/mmdeploy/codebase/mmdet/deploy/object_detection_model.py", line 299, in forward
    self.postprocessing_results(batch_dets, batch_labels, batch_masks,
  File "/root/workspace/mmdeploy/mmdeploy/codebase/mmdet/deploy/object_detection_model.py", line 189, in postprocessing_results
    outputs = End2EndModel.__clear_outputs(tmp_outputs)
  File "/root/workspace/mmdeploy/mmdeploy/codebase/mmdet/deploy/object_detection_model.py", line 104, in __clear_outputs
    outputs[2][i] = test_outputs[2][i, inds, ...]
IndexError: The shape of the mask [100] at index 0 does not match the shape of the indexed tensor [99, 640, 640] at index 0
11/28 12:10:34 - mmengine - ERROR - tools/deploy.py - create_process - 82 - visualize onnxruntime model failed.
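The failing line, `outputs[2][i] = test_outputs[2][i, inds, ...]`, applies a boolean index built from the detections to the mask output, so it breaks as soon as the mask output has one fewer instance than the detections (100 vs. 99 here). Below is a minimal NumPy re-creation of that filtering step. `clear_outputs` is hypothetical, and the assumption that `inds` comes from the score column (index 4) of the detections is not confirmed from mmdeploy's source. PyTorch produces the message shown in the log; NumPy raises the same `IndexError` type with different wording.

```python
import numpy as np


def clear_outputs(dets, labels, masks):
    """Hypothetical re-creation of the per-image filtering in the traceback.

    Keeps detections whose score (assumed to be column 4 of dets) is > 0
    and applies the same boolean index to labels and masks.
    """
    kept = []
    for i in range(dets.shape[0]):
        inds = dets[i, :, 4] > 0.0  # boolean index with num_dets entries
        kept.append((dets[i, inds], labels[i, inds], masks[i, inds, ...]))
    return kept


dets = np.ones((1, 100, 5), dtype=np.float32)  # 100 detections, all scored
labels = np.zeros((1, 100), dtype=np.int64)
masks = np.zeros((1, 99, 4, 4), dtype=bool)    # one mask short, as in the log

try:
    clear_outputs(dets, labels, masks)
except IndexError as err:
    # The 100-entry boolean index cannot be applied to 99 masks.
    print("IndexError:", err)
```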

@RunningLeon, this bug should already be fixed on main, right? Am I missing something?

LuukvandenBent changed the title [Bug] Rtm_det-ins incorrect mask dimensions after ONNX export Nov 28, 2023
RunningLeon self-assigned this Nov 29, 2023

RunningLeon commented Nov 29, 2023

Hi, there might be something wrong with the onnxruntime static-shape config.
You could change this line:

bbox_index = bbox_index[:, topk_inds[:-1]]

to:

bbox_index = bbox_index[:, topk_inds]
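The suggested change removes the `[:-1]` slice from `topk_inds`. With 100 detections kept, `topk_inds[:-1]` keeps only 99 indices, so anything gathered through `bbox_index` (presumably the per-instance mask data) comes out one instance short, which matches the `[99, 640, 640]` tensor in the traceback. A NumPy sketch of the two indexing expressions follows; the array contents, the candidate count of 200 and `keep_top_k = 100` are illustrative assumptions, not values read from mmdeploy's code.

```python
import numpy as np

keep_top_k = 100
# Hypothetical stand-ins: indices of the kept detections, and a per-candidate
# index (shape (1, num_candidates)) later used to select each instance's mask.
topk_inds = np.arange(keep_top_k)
bbox_index = np.arange(200)[None, :]

before = bbox_index[:, topk_inds[:-1]]  # original line: drops the last index
after = bbox_index[:, topk_inds]        # suggested change: keeps all of them

print(before.shape, after.shape)
# -> (1, 99) (1, 100)
```

After the change, the index used to select masks has as many entries as there are detections, so the boolean filtering step no longer sees mismatched lengths.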

BTW, your model config and checkpoint don't match.
https://download.openmmlab.com/mmdetection/v3.0/rtmdet/cspnext_rsb_pretrain/cspnext-s_imagenet_600e.pth is an ImageNet-pretrained backbone checkpoint meant for classification.
You should use https://download.openmmlab.com/mmdetection/v3.0/rtmdet/rtmdet-ins_s_8xb32-300e_coco/rtmdet-ins_s_8xb32-300e_coco_20221121_212604-fdc5d7ec.pth to match configs/rtmdet/rtmdet-ins_s_8xb32-300e_coco.py
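The log above shows this mismatch: every `neck.*` and `bbox_head.*` parameter is reported as missing, and the only unexpected keys are `head.fc.weight` and `head.fc.bias`, a classifier head. Below is a small pure-Python sketch that condenses such a report by top-level module. `summarize_key_mismatch` is a hypothetical helper; in practice the key lists could come from the model's `state_dict().keys()` and the checkpoint's state dict. The `backbone.*` key is illustrative and not taken from the log.

```python
from collections import Counter


def summarize_key_mismatch(model_keys, ckpt_keys):
    """Count missing/unexpected state_dict keys per top-level module name."""
    model_keys, ckpt_keys = set(model_keys), set(ckpt_keys)
    missing = Counter(k.split(".")[0] for k in model_keys - ckpt_keys)
    unexpected = Counter(k.split(".")[0] for k in ckpt_keys - model_keys)
    return dict(missing), dict(unexpected)


# A few keys from the log; real state dicts contain hundreds of entries.
model_keys = [
    "backbone.stem.0.conv.weight",  # illustrative, not from the log
    "neck.reduce_layers.0.conv.weight",
    "bbox_head.rtm_cls.0.weight",
    "bbox_head.mask_head.projection.weight",
]
ckpt_keys = ["backbone.stem.0.conv.weight", "head.fc.weight", "head.fc.bias"]

missing, unexpected = summarize_key_mismatch(model_keys, ckpt_keys)
print("missing:", missing)
print("unexpected:", unexpected)
```

A report where whole detector modules are missing and only a classification head is unexpected points at the wrong checkpoint rather than at an export problem.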


This worked perfectly! Thanks a lot for the quick reply.


LuukvandenBent commented Nov 29, 2023

I've created a PR for this fix: #2574.
