IIfConditionalOutputLayer inputs must have the same shape. #2846

Closed

dreamkwc opened this issue Apr 3, 2023 · 8 comments

Assignees: zerollzeng
Labels: Enhancement (New feature or request), triaged (Issue has been triaged by maintainers)

dreamkwc commented Apr 3, 2023

Description

I want to export the ViTDet model from detectron2 to TensorRT. I first exported the model using torch.onnx.export, then tried to parse the resulting ONNX file with TensorRT and got this error:

[04/03/2023-11:10:33] [TRT] [W] CUDA lazy loading is not enabled. Enabling it can significantly reduce device memory usage and speed up TensorRT initialization. See "Lazy Loading" section of CUDA documentation https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#lazy-loading
[04/03/2023-11:10:34] [TRT] [W] onnx2trt_utils.cpp:374: Your ONNX model has been generated with INT64 weights, while TensorRT does not natively support INT64. Attempting to cast down to INT32.
[04/03/2023-11:10:34] [TRT] [W] onnx2trt_utils.cpp:400: One or more weights outside the range of INT32 was clamped
[04/03/2023-11:10:34] [TRT] [E] [graph.cpp::symbolicExecute::535] Error Code 4: Internal Error (/ScatterND: an IScatterLayer cannot be used to compute a shape tensor)
[04/03/2023-11:10:34] [TRT] [E] ModelImporter.cpp:768: While parsing node number 93 [Pad -> "/Pad_output_0"]:
[04/03/2023-11:10:34] [TRT] [E] ModelImporter.cpp:769: --- Begin node ---
[04/03/2023-11:10:34] [TRT] [E] ModelImporter.cpp:770: input: "/Div_output_0"
input: "/Cast_1_output_0"
input: "/Constant_36_output_0"
output: "/Pad_output_0"
name: "/Pad"
op_type: "Pad"
attribute {
  name: "mode"
  s: "constant"
  type: STRING
}
[04/03/2023-11:10:34] [TRT] [E] ModelImporter.cpp:771: --- End node ---
[04/03/2023-11:10:34] [TRT] [E] ModelImporter.cpp:774: ERROR: ModelImporter.cpp:195 In function parseGraph:
[6] Invalid Node - /Pad
[graph.cpp::symbolicExecute::535] Error Code 4: Internal Error (/ScatterND: an IScatterLayer cannot be used to compute a shape tensor)
In node 93 (parseGraph): INVALID_NODE: Invalid Node - /Pad
[graph.cpp::symbolicExecute::535] Error Code 4: Internal Error (/ScatterND: an IScatterLayer cannot be used to compute a shape tensor)
Traceback (most recent call last):
  File "export.py", line 111, in <module>
    export_engine()
  File "export.py", line 100, in export_engine
    raise RuntimeError(f'failed to load ONNX file: {onnx_model}')
RuntimeError: failed to load ONNX file: weights/model.onnx
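For reference, this first failure (an IScatterLayer being asked to compute a shape tensor) is typically resolved by constant-folding the shape-computation subgraphs, which is what the onnxsim step below does. Polygraphy offers the same folding from the command line, assuming the polygraphy package is installed:

polygraphy surgeon sanitize weights/model.onnx --fold-constants -o weights/model_folded.onnx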

Then I used onnxsim to simplify the ONNX model. That error was solved, but another one appeared:

[04/03/2023-11:12:00] [TRT] [W] CUDA lazy loading is not enabled. Enabling it can significantly reduce device memory usage and speed up TensorRT initialization. See "Lazy Loading" section of CUDA documentation https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#lazy-loading
[04/03/2023-11:12:01] [TRT] [W] onnx2trt_utils.cpp:374: Your ONNX model has been generated with INT64 weights, while TensorRT does not natively support INT64. Attempting to cast down to INT32.
[04/03/2023-11:12:01] [TRT] [W] onnx2trt_utils.cpp:400: One or more weights outside the range of INT32 was clamped
[04/03/2023-11:12:01] [TRT] [E] /roi_heads/box_pooler/level_poolers.0/If_OutputLayer: IIfConditionalOutputLayer inputs must have the same shape.
[04/03/2023-11:12:01] [TRT] [E] ModelImporter.cpp:768: While parsing node number 1223 [If -> "/roi_heads/box_pooler/level_poolers.0/If_output_0"]:
[04/03/2023-11:12:01] [TRT] [E] ModelImporter.cpp:769: --- Begin node ---
[04/03/2023-11:12:01] [TRT] [E] ModelImporter.cpp:770: input: "/roi_heads/box_pooler/level_poolers.0/Equal_output_0"
output: "/roi_heads/box_pooler/level_poolers.0/If_output_0"
name: "/roi_heads/box_pooler/level_poolers.0/If"
op_type: "If"
attribute {
  name: "then_branch"
  g {
    node {
      input: "/roi_heads/box_pooler/level_poolers.0/Gather_output_0"
      input: "/roi_heads/box_pooler/level_poolers.0/Constant_3_output_0"
      output: "/roi_heads/box_pooler/level_poolers.0/Squeeze_output_0"
      name: "/roi_heads/box_pooler/level_poolers.0/Squeeze"
      op_type: "Squeeze"
    }
    name: "torch_jit3"
    initializer {
      dims: 1
      data_type: 7
      name: "/roi_heads/box_pooler/level_poolers.0/Constant_3_output_0"
      raw_data: "\001\000\000\000\000\000\000\000"
    }
    output {
      name: "/roi_heads/box_pooler/level_poolers.0/Squeeze_output_0"
      type {
        tensor_type {
          elem_type: 1
          shape {
            dim {
              dim_param: "Squeeze/roi_heads/box_pooler/level_poolers.0/Squeeze_output_0_dim_0"
            }
          }
        }
      }
    }
  }
  type: GRAPH
}
attribute {
  name: "else_branch"
  g {
    node {
      input: "/roi_heads/box_pooler/level_poolers.0/Gather_output_0"
      output: "/roi_heads/box_pooler/level_poolers.0/Identity_output_0"
      name: "/roi_heads/box_pooler/level_poolers.0/Identity"
      op_type: "Identity"
    }
    name: "torch_jit4"
    output {
      name: "/roi_heads/box_pooler/level_poolers.0/Identity_output_0"
      type {
        tensor_type {
          elem_type: 1
          shape {
            dim {
              dim_param: "Squeeze/roi_heads/box_pooler/level_poolers.0/Squeeze_output_0_dim_0"
            }
            dim {
              dim_value: 1
            }
          }
        }
      }
    }
  }
  type: GRAPH
}
[04/03/2023-11:12:01] [TRT] [E] ModelImporter.cpp:771: --- End node ---
[04/03/2023-11:12:01] [TRT] [E] ModelImporter.cpp:774: ERROR: ModelImporter.cpp:195 In function parseGraph:
[6] Invalid Node - /roi_heads/box_pooler/level_poolers.0/If
/roi_heads/box_pooler/level_poolers.0/If_OutputLayer: IIfConditionalOutputLayer inputs must have the same shape.
In node 1223 (parseGraph): INVALID_NODE: Invalid Node - /roi_heads/box_pooler/level_poolers.0/If
/roi_heads/box_pooler/level_poolers.0/If_OutputLayer: IIfConditionalOutputLayer inputs must have the same shape.
Traceback (most recent call last):
  File "export.py", line 111, in <module>
    export_engine()
  File "export.py", line 100, in export_engine
    raise RuntimeError(f'failed to load ONNX file: {onnx_model}')
RuntimeError: failed to load ONNX file: weights/model_sim.onnx

I don't know what I should do. Any suggestions?

Environment

I used the docker image nvidia/cuda:11.7.1-cudnn8-devel-ubuntu20.04.
TensorRT was installed via pip:
pip install tensorrt
This also pulled in three other NVIDIA packages:
nvidia-cublas-cu12: 12.1.0.26
nvidia-cuda-runtime-cu12: 12.1.55
nvidia-cudnn-cu12: 8.8.1.3

TensorRT Version: 8.6.0
NVIDIA GPU: 3080
NVIDIA Driver Version: 525.105.17
CUDA Version: 11.7.1
CUDNN Version: 8
Operating System: ubuntu20.04
Python Version (if applicable): Python 3.8.10
Tensorflow Version (if applicable):
PyTorch Version (if applicable): 1.13.1+cu117
Baremetal or Container (if so, version):

Other packages:
detectron2: 0.6
onnx: 1.13.1
onnxruntime-gpu: 1.14.1
onnxsim: 0.4.17

Steps To Reproduce

Export to ONNX. The model weights can be downloaded from the detectron2 GitHub; I used the COCO Mask R-CNN ViTDet, ViT-B checkpoint.

def export_onnx():
    import os
    import torch
    from detectron2 import model_zoo
    from detectron2.config import instantiate
    from detectron2.checkpoint import DetectionCheckpointer
    from detectron2.modeling import GeneralizedRCNN
    from detectron2.utils.file_io import PathManager
    from detectron2.export import TracingAdapter
    import onnx
    import onnxsim
    import numpy as np

    device = 'cuda:0'
    weight = os.path.join("weights", "model_final_61ccd1.pkl")

    model = model_zoo.get_config("common/models/mask_rcnn_vitdet.py").model
    model = instantiate(model).to(torch.device(device))
    DetectionCheckpointer(model).load(weight)
    model.eval()


    batch_size = 1
    image = np.ones([3, 1024, 1024], dtype=np.float32)
    image = torch.from_numpy(image.copy()).float().to(torch.device(device))
    inputs = [{"image": image}]

    if not os.path.exists(os.path.join("weights", "model.onnx")):
        if isinstance(model, GeneralizedRCNN):
            def inference(model, inputs):
                # use do_postprocess=False so it returns ROI mask
                with torch.no_grad():
                    inst = model.inference(inputs, do_postprocess=False)[0]
                return [{"instances": inst}]
        else:
            inference = None  # assume that we just call the model directly

        traceable_model = TracingAdapter(model, inputs, inference)
        with torch.no_grad():
            with PathManager.open(os.path.join("weights", "model.onnx"), "wb") as f:
                torch.onnx.export(traceable_model, (image,), f, opset_version=16)

    model_onnx = onnx.load(os.path.join("weights", "model.onnx"))
    model_onnx, check = onnxsim.simplify(model_onnx)
    assert check, 'Simplified ONNX model could not be validated'
    onnx.save(model_onnx, os.path.join("weights", "model_sim.onnx"))

Export to engine.

def export_engine():
    import os
    import tensorrt as trt

    onnx_model = os.path.join("weights", "model_sim.onnx")
    engine_model = os.path.join("weights", "model.engine")

    # To create a builder, you must first create a logger.
    logger = trt.Logger(trt.Logger.WARNING)
    # Create a builder.
    builder = trt.Builder(logger)

    # Create a network definition (explicit batch is required by the ONNX parser).
    network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))

    # Import the model using the ONNX parser.
    parser = trt.OnnxParser(network, logger)
    success = parser.parse_from_file(onnx_model)
    for idx in range(parser.num_errors):
        print(parser.get_error(idx))
    if not success:
        raise RuntimeError(f'failed to load ONNX file: {onnx_model}')

    # Build the engine. max_workspace_size is deprecated in TensorRT 8.6;
    # set_memory_pool_limit is the current equivalent.
    config = builder.create_builder_config()
    config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 30)
    serialized_engine = builder.build_serialized_network(network, config)
    if serialized_engine is None:
        raise RuntimeError('engine build failed')
    with open(engine_model, 'wb') as f:
        f.write(serialized_engine)
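As an aside, the "CUDA lazy loading is not enabled" warning in the logs above can be addressed by setting the CUDA_MODULE_LOADING environment variable (supported from CUDA 11.7) before CUDA is initialized; a minimal sketch:

import os
# Must be set before the first CUDA context is created, i.e. before
# importing/initializing tensorrt or torch.cuda in the process.
os.environ["CUDA_MODULE_LOADING"] = "LAZY"

This only affects device memory usage and startup time; it is unrelated to the parser errors above.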
zerollzeng (Collaborator) commented:

/roi_heads/box_pooler/level_poolers.0/If_OutputLayer: IIfConditionalOutputLayer inputs must have the same shape.

Currently we only support this mode (both branches of an If must produce outputs of the same shape). @jhalakp-nvidia, do we have a plan to improve it? Also, I didn't see this documented in our API docs; I think we can improve that too.

zerollzeng self-assigned this Apr 3, 2023
zerollzeng added the triaged (Issue has been triaged by maintainers) label Apr 3, 2023
ttyio added the Enhancement (New feature or request) label Jul 18, 2023
ttyio (Collaborator) commented Aug 15, 2023

Sorry, we have no plan so far to support If conditionals whose branches output different shapes.

ttyio closed this as completed Aug 15, 2023
caruofc commented Sep 5, 2023

Hi @dreamkwc, I am facing the same issue. Were you able to figure out which if/else branch in the detectron2 source code is causing it? Thanks.

cbenitez81 commented:

@caruofc I just faced the same issue exporting an ONNX model created with anomalib. The squeeze function in PyTorch implies an If statement (remove the dimension if that shape element equals 1). I modified the code to index that case instead, e.g. [:, 0, :], which effectively squeezes the dimension without the conditional; see the sketch below. Hope it is similar for detectron2.
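For illustration, a minimal sketch of the pattern described here: when a dimension's size is not known statically (e.g. it is declared dynamic), torch.onnx.export lowers squeeze into an ONNX If whose then-branch (Squeeze) and else-branch (Identity) return different ranks, which TensorRT rejects. Explicit indexing avoids the conditional. The module names here are hypothetical:

import torch

class WithSqueeze(torch.nn.Module):
    def forward(self, x):
        # With a dynamic dim 1, this traces to an ONNX If:
        # the then-branch squeezes (rank drops), the else-branch is
        # Identity (rank kept), so the branch output shapes differ.
        return x.squeeze(1)

class WithIndexing(torch.nn.Module):
    def forward(self, x):
        # Always drops dim 1, so no data-dependent If is emitted.
        return x[:, 0, :]

x = torch.randn(4, 1, 8)
torch.onnx.export(
    WithSqueeze(), (x,), "with_squeeze.onnx",
    input_names=["x"], opset_version=16,
    dynamic_axes={"x": {1: "d1"}},  # dim 1 dynamic -> squeeze becomes an If
)
torch.onnx.export(
    WithIndexing(), (x,), "with_indexing.onnx",
    input_names=["x"], opset_version=16,
    dynamic_axes={"x": {1: "d1"}},  # plain indexing, no If
)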

caruofc commented Sep 16, 2023

@cbenitez81, thank you so much for your reply. Your solution makes sense. It would be a great help if you could assist me in identifying the squeeze block I need to change in the detectron2 source code. I put breakpoints on every squeeze call in the detectron2 code, but none of them were hit when I created the ONNX from the .pth weight file, so I am not sure how to find the code segment that needs modification. Please help. Thanks.

cbenitez81 commented Sep 27, 2023

@caruofc sorry for the delay. I haven't tested this, but your log shows the following: [04/03/2023-11:12:01] [TRT] [E] /roi_heads/box_pooler/level_poolers.0/If_OutputLayer: IIfConditionalOutputLayer inputs must have the same shape. So the problem seems to be in roi_heads. Looking at this link in detectron2, there is a squeeze inside that object. Probably changing it to fg_selection_mask.nonzero()[:,0,...] would do the trick; see the sketch below.
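For concreteness, a hedged sketch of what that change could look like. This assumes the squeeze in question is the one in select_foreground_proposals in detectron2/modeling/roi_heads/roi_heads.py; verify against your detectron2 version before patching:

# detectron2/modeling/roi_heads/roi_heads.py (select_foreground_proposals)
# Original line: squeeze(1) exports as an ONNX If with mismatched
# branch shapes when the dim size is not statically known.
fg_idxs = fg_selection_mask.nonzero().squeeze(1)

# Workaround: nonzero() on a 1-D mask returns shape (N, 1), so indexing
# column 0 always yields shape (N,) without emitting a conditional.
fg_idxs = fg_selection_mask.nonzero()[:, 0]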

caruofc commented Sep 28, 2023

@cbenitez81, I finally solved my issue. This link https://github.com/NVIDIA/TensorRT/tree/release/8.6/samples/python/detectron2 describes how to get around this with a sample Mask R-CNN model. Basically, I had to modify the "create_onnx.py" sample script to create an NMS node for the EfficientNMS_TRT plugin and replace the output, create a PyramidROIAlign_TRT plugin node to replace the ROIAligns, fold the constants, and modify the Reshape nodes of my ONNX model. Once the ONNX model was converted using the modified "create_onnx.py", I could generate the engine file without any issues. (A sketch of this kind of graph surgery follows below.)

Thank you for your help and providing me with your valuable feedback.
Appreciate it a lot.
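For illustration, a minimal onnx-graphsurgeon sketch of the kind of graph surgery described above: folding constants and grafting an EfficientNMS_TRT node onto the graph. The tensor names "boxes" and "scores" are placeholders for the model's actual pre-NMS outputs, and the attribute values are assumptions rather than the sample script's exact settings:

import numpy as np
import onnx
import onnx_graphsurgeon as gs

graph = gs.import_onnx(onnx.load("weights/model_sim.onnx"))
graph.fold_constants()  # fold shape-computation subgraphs into constants

# Placeholder names: look up the real pre-NMS box/score tensors in your graph.
tensors = graph.tensors()
boxes, scores = tensors["boxes"], tensors["scores"]

# Outputs produced by the EfficientNMS_TRT plugin.
num_dets = gs.Variable("num_detections", dtype=np.int32, shape=[1, 1])
det_boxes = gs.Variable("detection_boxes", dtype=np.float32, shape=[1, 100, 4])
det_scores = gs.Variable("detection_scores", dtype=np.float32, shape=[1, 100])
det_classes = gs.Variable("detection_classes", dtype=np.int32, shape=[1, 100])

graph.nodes.append(gs.Node(
    op="EfficientNMS_TRT",
    attrs={
        "plugin_version": "1",
        "background_class": -1,
        "max_output_boxes": 100,
        "score_threshold": 0.05,
        "iou_threshold": 0.5,
        "score_activation": False,
        "box_coding": 0,
    },
    inputs=[boxes, scores],
    outputs=[num_dets, det_boxes, det_scores, det_classes],
))
graph.outputs = [num_dets, det_boxes, det_scores, det_classes]

graph.cleanup().toposort()
onnx.save(gs.export_onnx(graph), "weights/model_nms.onnx")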

justanhduc commented:

Hey @caruofc, could you please let me know in detail what you did?
