
Issue loading custom ONNX model with complex-valued operations in ONNX Runtime (C++) #23341

Open
xefonon opened this issue Jan 13, 2025 · 2 comments
Labels
core runtime issues related to core runtime

Comments

xefonon commented Jan 13, 2025

Describe the issue

I have implemented a custom model with complex operations (e.g., ComplexLinear, ApplyComplex) and exported it to ONNX using PyTorch. However, when trying to load the model in ONNX Runtime (C++), I encounter an error related to the schema validation of the custom operation.

Environment:

  • OS: Ubuntu 22.04
  • ONNX version: 1.17.0
  • ONNX Runtime version (Python): 1.20.1
  • ONNX Runtime version (C++): 1.14.1
  • PyTorch version: 2.5.1
  • Compiler version: GCC 11.4.0

Expected behavior:

The ONNX model with custom operations should load successfully in ONNX Runtime, and inference should run without issues.

Actual behavior:

The model fails to load in ONNX Runtime C++ due to the error:
ONNX Schema __main___ApplyComplex_fc_apply_complex_1: failed validating the check: !(it.GetName().empty())

To reproduce

Python code:

import torch
import torch.nn as nn

import onnx
from pathlib import Path
import onnxruntime

# Define a simple model using ComplexLinear

class ApplyComplex(nn.Module):
    
    def __init__(self):
        super(ApplyComplex, self).__init__()

    def forward(self, fr, fi, input, dtype = torch.complex64):
        return (fr(input.real)-fi(input.imag)).type(dtype) \
                + 1j*(fr(input.imag)+fi(input.real)).type(dtype)

class RealToComplex(nn.Module):
    def __init__(self):
        super(RealToComplex, self).__init__()

    def forward(self, x):
        # Convert a real tensor with trailing dimension 2 to a complex tensor
        return torch.view_as_complex(x)

class ComplexLinear(nn.Module):

    def __init__(self, in_features, out_features):
        super(ComplexLinear, self).__init__()
        self.apply_complex = ApplyComplex()
        self.fc_r = nn.Linear(in_features, out_features)
        self.fc_i = nn.Linear(in_features, out_features)

    def forward(self, input):
        return self.apply_complex(self.fc_r, self.fc_i, input)

class SimpleModel(torch.nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        # self.conv_test1 = ComplexConv2d(1, 10, kernel_size=3)
        # self.conv_test2 = ComplexConv2d(10, 20, kernel_size=3)
        self.fc = ComplexLinear(128, 1)
        self.real_to_cmplx = RealToComplex()


    def forward(self, x):
        x = self.real_to_cmplx(x)
        x = self.fc(x)
        return x

if __name__ == "__main__":


    model = SimpleModel()
    model_path = "simple_model.onnx"
    # Dummy input for model tracing
    dummy_input = torch.randn(1, 128, 2, dtype=torch.float32)

    onnx_registry = torch.onnx.OnnxRegistry()
    onnx_registry.register_op(ComplexLinear, namespace="ComplexLayers", op_name="ComplexLinear", is_complex=True)

    # Export the model to ONNX
    args = (dummy_input,)
    export_options = torch.onnx.ExportOptions(onnx_registry=onnx_registry, dynamic_shapes=True)
    # torch.onnx.dynamo_export(model, *args).save("simple_model.onnx")
    torch.onnx.dynamo_export(model, *args, export_options=export_options).save(model_path)
    print("Model exported to simple_model.onnx")
    onnx_model = onnx.load("simple_model.onnx")
    onnx.checker.check_model(onnx_model, full_check=True)

    for node in onnx_model.graph.node:
        print(f"Node: {node.name}, OpType: {node.op_type}, Domain: {node.domain}")
    
    ort_session = onnxruntime.InferenceSession(model_path, providers=["CPUExecutionProvider"])
    print("Model loaded in onnxruntime")
    ortvalue = onnxruntime.OrtValue.ortvalue_from_numpy(dummy_input.numpy())
    # run the model
    ort_inputs = {ort_session.get_inputs()[0].name: ortvalue}
    ort_outs = ort_session.run(None, ort_inputs)
    print(ort_outs)
    print(ort_outs[0].shape)

    print("\n")
    onnx_model = onnx.load("simple_model.onnx")
    print("Inspecting ONNX nodes:")
    for node in onnx_model.graph.node:
        print(f"Name: {node.name}, OpType: {node.op_type}, Domain: {node.domain}")

    for op in onnx_model.opset_import:
        print(f"Domain: {op.domain}, Version: {op.version}")

With the following output:

.../python3.11/site-packages/torch/onnx/_internal/_exporter_legacy.py:116: UserWarning: torch.onnx.dynamo_export only implements opset version 18 for now. If you need to use a different opset version, please register them with register_custom_op.
  warnings.warn(
.../python3.11/site-packages/onnxscript/converter.py:820: FutureWarning: 'onnxscript.values.Op.param_schemas' is deprecated in version 0.1 and will be removed in the future. Please use '.op_signature' instead.
  param_schemas = callee.param_schemas()

Applied 2 of general pattern rewrite rules.
Model exported to simple_model.onnx
Node: __main___ComplexLinear_fc_1_1___main___ApplyComplex_fc_apply_complex_1_0, OpType: __main___ApplyComplex_fc_apply_complex_1, Domain: pkg.__main__
Model loaded in onnxruntime
[array([[[-0.7650722 , -0.41634345]]], dtype=float32)]
(1, 1, 2)


Inspecting ONNX nodes:
Name: __main___ComplexLinear_fc_1_1___main___ApplyComplex_fc_apply_complex_1_0, OpType: __main___ApplyComplex_fc_apply_complex_1, Domain: pkg.__main__
Domain: pkg.__main__, Version: 1
Domain: pkg.torch.2.5.1, Version: 1
Domain: , Version: 18
Domain: pkg.onnxscript.torch_lib.common, Version: 1

And the C++ snippet:

void NeuralNetInference::LoadNet(const std::string &model_path)
{
    try {
        session_options.SetIntraOpNumThreads(1);
        session_options.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_EXTENDED);
        session = std::make_unique<Ort::Session>(env, model_path.c_str(), session_options);
        PrintNet();
    } catch (const Ort::Exception &e) {
        std::cerr << "Failed to load model from " << model_path << ": " << e.what() << std::endl;
        throw;
    } catch (const std::exception &e) {
        std::cerr << "Standard exception caught while loading model: " << e.what() << std::endl;
        throw;
    } catch (...) {
        std::cerr << "Unknown exception caught while loading model." << std::endl;
        throw;
    }
}

which outputs the following error:

Failed to load model from /usr/src/repo/resources/vr_complex_cnn_model.onnx: Load model from /usr/src/repo/resources/vr_complex_cnn_model.onnx failed:ONNX Schema __main___ApplyComplex_fc_apply_complex_1: failed validating the check: !(it.GetName().empty())
Exception caught in NeuralNetInference constructor: Load model from /usr/src/repo/resources/vr_complex_cnn_model.onnx failed:ONNX Schema __main___ApplyComplex_fc_apply_complex_1: failed validating the check: !(it.GetName().empty()

Urgency

No response

Platform

Linux

OS Version

Ubuntu 22.04

ONNX Runtime Installation

Released Package

ONNX Runtime Version or Commit ID

1.14.1

ONNX Runtime API

C++

Architecture

X64

Execution Provider

Default CPU

Execution Provider Library Version

No response

yuslepukhin added the "core runtime" label Jan 13, 2025
yuslepukhin (Member) commented:
The error message comes from the ONNX library code where your custom op schema is declared and finalized. It basically states that all inputs and outputs must have names. https://github.com/onnx/onnx/blob/main/onnx/defs/schema.cc#L1461
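
A minimal sketch, assuming the simple_model.onnx produced by the script above and that the failing schema is built from the exported FunctionProto, that mirrors this check by flagging any empty input or output names:

import onnx

model = onnx.load("simple_model.onnx")

# The dynamo exporter stores custom ops such as ApplyComplex as local
# FunctionProtos; when the runtime builds an OpSchema from one, this
# appears to be where the !(it.GetName().empty()) check fires.
for f in model.functions:
    for kind, names in (("input", f.input), ("output", f.output)):
        for i, name in enumerate(names):
            if not name:
                print(f"Function {f.domain}.{f.name}: {kind} #{i} is unnamed")

# For comparison, the same check over the top-level graph nodes.
for node in model.graph.node:
    for kind, names in (("input", node.input), ("output", node.output)):
        for i, name in enumerate(names):
            if not name:
                print(f"Node {node.name}: {kind} #{i} is unnamed")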

xefonon (Author) commented Jan 14, 2025

Thanks @yuslepukhin for replying! It's not clear to me how to name the inputs and outputs in the Python code. Could you provide more insight?

Edit: I can see that all inputs and outputs are actually named:

[Screenshots showing the named inputs and outputs of the exported model]
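
If the unnamed entries turn out to be formal parameters of the exported FunctionProto rather than the graph inputs/outputs shown in the screenshots, one hedged workaround sketch (the placeholder names here are hypothetical, not an official fix) is to patch the names before handing the model to the C++ runtime:

import onnx

model = onnx.load("simple_model.onnx")
for f in model.functions:
    # Hypothetical workaround: replace empty formal parameter names with
    # placeholders; nothing inside the function body can reference an
    # empty name, so the rename should be safe.
    for i, name in enumerate(f.input):
        if not name:
            f.input[i] = f"{f.name}_in_{i}"
    for i, name in enumerate(f.output):
        if not name:
            f.output[i] = f"{f.name}_out_{i}"
onnx.save(model, "simple_model_patched.onnx")

This only papers over the symptom; the underlying question of why the exporter emits unnamed parameters would still stand.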
