Describe the issue
I have implemented a custom model with complex-valued operations (e.g., ComplexLinear, ApplyComplex) and exported it to ONNX using PyTorch. However, when loading the model in ONNX Runtime (C++), I encounter an error during schema validation of the custom operation.
Environment:
OS: Ubuntu 22.04
ONNX version: 1.17.0
ONNX Runtime version (Python): 1.20.1
ONNX Runtime version (C++): 1.14.1
PyTorch version: 2.5.1
Compiler version: GCC 11.4.0
Expected behavior:
The ONNX model with custom operations should load successfully in ONNX Runtime, and the inference should proceed without issues.
Actual behavior:
The model fails to load in ONNX Runtime C++ with the following error:
ONNX Schema __main___ApplyComplex_fc_apply_complex_1: failed validating the check: !(it.GetName().empty())
To reproduce
Python code:
import torch
import torch.nn as nn
import onnx
import onnxruntime

# Define a simple model using ComplexLinear
class ApplyComplex(nn.Module):
    def __init__(self):
        super(ApplyComplex, self).__init__()

    def forward(self, fr, fi, input, dtype=torch.complex64):
        return (fr(input.real) - fi(input.imag)).type(dtype) \
            + 1j * (fr(input.imag) + fi(input.real)).type(dtype)

class RealToComplex(nn.Module):
    def __init__(self):
        super(RealToComplex, self).__init__()

    def forward(self, x):
        # Convert a real tensor with a trailing dimension of 2 to a complex tensor
        return torch.view_as_complex(x)

class ComplexLinear(nn.Module):
    def __init__(self, in_features, out_features):
        super(ComplexLinear, self).__init__()
        self.apply_complex = ApplyComplex()
        self.fc_r = nn.Linear(in_features, out_features)
        self.fc_i = nn.Linear(in_features, out_features)

    def forward(self, input):
        return self.apply_complex(self.fc_r, self.fc_i, input)

class SimpleModel(torch.nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.fc = ComplexLinear(128, 1)
        self.real_to_cmplx = RealToComplex()

    def forward(self, x):
        x = self.real_to_cmplx(x)
        x = self.fc(x)
        return x

if __name__ == "__main__":
    model = SimpleModel()
    model_path = "simple_model.onnx"

    # Dummy input for model tracing
    dummy_input = torch.randn(1, 128, 2, dtype=torch.float32)

    onnx_registry = torch.onnx.OnnxRegistry()
    onnx_registry.register_op(
        ComplexLinear, namespace="ComplexLayers", op_name="ComplexLinear", is_complex=True
    )

    # Export the model to ONNX
    args = (dummy_input,)
    export_options = torch.onnx.ExportOptions(onnx_registry=onnx_registry, dynamic_shapes=True)
    torch.onnx.dynamo_export(model, *args, export_options=export_options).save(model_path)
    print("Model exported to simple_model.onnx")

    onnx_model = onnx.load(model_path)
    onnx.checker.check_model(onnx_model, full_check=True)
    for node in onnx_model.graph.node:
        print(f"Node: {node.name}, OpType: {node.op_type}, Domain: {node.domain}")

    ort_session = onnxruntime.InferenceSession(model_path, providers=["CPUExecutionProvider"])
    print("Model loaded in onnxruntime")
    ortvalue = onnxruntime.OrtValue.ortvalue_from_numpy(dummy_input.numpy())

    # Run the model
    ort_inputs = {ort_session.get_inputs()[0].name: ortvalue}
    ort_outs = ort_session.run(None, ort_inputs)
    print(ort_outs)
    print(ort_outs[0].shape)
    print("\n")

    onnx_model = onnx.load(model_path)
    print("Inspecting ONNX nodes:")
    for node in onnx_model.graph.node:
        print(f"Name: {node.name}, OpType: {node.op_type}, Domain: {node.domain}")
    for op in onnx_model.opset_import:
        print(f"Domain: {op.domain}, Version: {op.version}")
With the following output:
.../python3.11/site-packages/torch/onnx/_internal/_exporter_legacy.py:116: UserWarning: torch.onnx.dynamo_export only implements opset version 18 for now. If you need to use a different opset version, please register them with register_custom_op.
warnings.warn(
.../python3.11/site-packages/onnxscript/converter.py:820: FutureWarning: 'onnxscript.values.Op.param_schemas' is deprecated in version 0.1 and will be removed in the future. Please use '.op_signature' instead.
param_schemas = callee.param_schemas()
Applied 2 of general pattern rewrite rules.
Model exported to simple_model.onnx
Node: __main___ComplexLinear_fc_1_1___main___ApplyComplex_fc_apply_complex_1_0, OpType: __main___ApplyComplex_fc_apply_complex_1, Domain: pkg.__main__
Model loaded in onnxruntime
[array([[[-0.7650722 , -0.41634345]]], dtype=float32)]
(1, 1, 2)
Inspecting ONNX nodes:
Name: __main___ComplexLinear_fc_1_1___main___ApplyComplex_fc_apply_complex_1_0, OpType: __main___ApplyComplex_fc_apply_complex_1, Domain: pkg.__main__
Domain: pkg.__main__, Version: 1
Domain: pkg.torch.2.5.1, Version: 1
Domain: , Version: 18
Domain: pkg.onnxscript.torch_lib.common, Version: 1
This is the code in C++:
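Since the exact C++ snippet was collapsed in the issue page and did not survive, here is a minimal sketch of a loader that exercises the same failure path. It is a reconstruction, not the reporter's actual code; it assumes the standard ONNX Runtime C++ API and takes the model path from the error log below.

#include <onnxruntime_cxx_api.h>
#include <iostream>

int main() {
    // Environment and default session options (default CPU execution provider).
    Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "complex_model_test");
    Ort::SessionOptions session_options;
    const char* model_path = "/usr/src/repo/resources/vr_complex_cnn_model.onnx";
    try {
        // Session construction is where ONNX Runtime validates the schemas
        // of the model's local functions, and where the load error is thrown.
        Ort::Session session(env, model_path, session_options);
        std::cout << "Model loaded successfully" << std::endl;
    } catch (const Ort::Exception& e) {
        std::cerr << "Failed to load model from " << model_path
                  << ": " << e.what() << std::endl;
        return 1;
    }
    return 0;
}

which outputs the following error: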
Failed to load model from /usr/src/repo/resources/vr_complex_cnn_model.onnx: Load model from /usr/src/repo/resources/vr_complex_cnn_model.onnx failed:ONNX Schema __main___ApplyComplex_fc_apply_complex_1: failed validating the check: !(it.GetName().empty())
Exception caught in NeuralNetInference constructor: Load model from /usr/src/repo/resources/vr_complex_cnn_model.onnx failed:ONNX Schema __main___ApplyComplex_fc_apply_complex_1: failed validating the check: !(it.GetName().empty())
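For context, the check !(it.GetName().empty()) appears to come from ONNX's OpSchema validation while ONNX Runtime registers a schema for each local function embedded in the model; it fires when a schema parameter has an empty name. A quick way to inspect those function protos (a sketch, assuming the exported simple_model.onnx from the repro above):

import onnx

# List the local functions the dynamo exporter embedded in the model.
# ONNX Runtime derives an operator schema from each of these at load time,
# so an empty input/output name here would trip the failing check.
model = onnx.load("simple_model.onnx")
for func in model.functions:
    print(f"Function: {func.domain}.{func.name}")
    print(f"  inputs:  {list(func.input)}")
    print(f"  outputs: {list(func.output)}")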
Urgency
No response
Platform
Linux
OS Version
Ubuntu 22.04
ONNX Runtime Installation
Released Package
ONNX Runtime Version or Commit ID
1.14.1
ONNX Runtime API
C++
Architecture
X64
Execution Provider
Default CPU
Execution Provider Library Version
No response