Softmax-13 should use the opset18 function when opset18 is used #16438
Comments
Looks like onnxruntime/onnxruntime/core/graph/graph.cc line 599 (commit 04dbdc9), which uses requested_opset_version, is related to https://github.com/onnx/onnx/blob/e9871305d4421ac9836e5fd2be2dc89bc0bb9dcf/onnx/defs/schema.cc#L729
Good catch. There are some related issues, listed below.
(i) Since ONNX now allows specifying the target ONNX opset for any function body, we need to add this target opset version to all function ops registered in onnxruntime (usually in the MSDOMAIN). For example, this Gelu function def should specify the target opset 13. (It is partially buried here, but we need to pass this as an explicit parameter to SetContextDependentFunctionBodyBuilder.)
(ii) SetContextDependentFunctionBodyBuilder has a default value for the opset. But the default behavior, which is to use the op's own opset version, works only for ops in the ONNX standard domain. For an op in MSDOMAIN opset version 1, using ONNX opset version 1 as the default value doesn't make sense. So we should either eliminate the use of the default value, or use it only for the ONNX standard domain, throwing an error/warning for other domains.
(iii) We should do something similar for context-independent functions.
That's correct. And the fix should be similar to that used for context-independent functions a few lines below, at onnxruntime/onnxruntime/core/graph/graph.cc line 602 (commit 04dbdc9), which requires the fixes discussed above.
Yet another fix required in ONNX is that BuildContextDependentFunction invokes validation, but returns true even if validation fails.
Summary
ONNX Runtime raises an INVALID_GRAPH error for the model below (full error stack further down). To recreate this report, use:
CREATE_REPRODUCTION_REPORT=1 python -m pytest onnxscript/tests/function_libs/torch_lib/ops_test.py -k test_output_match_opinfo__log_softmax_with_dtype_cpu_float16
To reproduce
import google.protobuf.text_format
import numpy as np
from numpy import array, float16, float32, float64, int32, int64
import onnx
import onnxruntime as ort
# Run n times
N = 1
onnx_model_text = """
ir_version: 8
producer_name: "pytorch"
producer_version: "2.1.0"
graph {
node {
input: "input_0"
output: "_val_1"
name: "aten_special_log_softmax_0"
op_type: "aten_special_log_softmax"
attribute {
name: "dim"
i: 2
type: INT
}
attribute {
name: "dtype"
i: 11
type: INT
}
doc_string: ""
domain: "pkg.onnxscript.torch_lib"
}
name: "torch_jit"
input {
name: "input_0"
type {
tensor_type {
elem_type: 10
shape {
dim {
dim_value: 5
}
dim {
dim_value: 10
}
dim {
dim_value: 5
}
}
}
}
}
output {
name: "_val_1"
type {
tensor_type {
elem_type: 11
shape {
dim {
dim_value: 5
}
dim {
dim_value: 10
}
dim {
dim_value: 5
}
}
}
}
}
}
opset_import {
domain: "pkg.onnxscript.torch_lib"
version: 1
}
opset_import {
domain: ""
version: 18
}
functions {
name: "aten_special_log_softmax"
input: "self"
output: "result_8"
attribute: "dim"
node {
input: "self"
output: "tmp"
name: "n0"
op_type: "Shape"
domain: ""
}
node {
input: "tmp"
output: "tmp_0"
name: "n1"
op_type: "Size"
domain: ""
}
node {
output: "int64_0"
name: "n2"
op_type: "Constant"
attribute {
name: "value"
t {
data_type: 7
int64_data: 0
name: "int64_0"
}
type: TENSOR
}
domain: ""
}
node {
input: "int64_0"
input: "tmp_0"
output: "int64_0_cast"
name: "n3"
op_type: "CastLike"
domain: ""
}
node {
input: "tmp_0"
input: "int64_0_cast"
output: "self_is_scalar"
name: "n4"
op_type: "Equal"
domain: ""
}
node {
input: "self_is_scalar"
output: "self_4"
name: "n5"
op_type: "If"
attribute {
name: "then_branch"
g {
node {
output: "tmp_1"
name: "n0"
op_type: "Constant"
attribute {
name: "value_ints"
ints: 0
type: INTS
}
domain: ""
}
node {
input: "self"
input: "tmp_1"
output: "self_2"
name: "n1"
op_type: "Unsqueeze"
domain: ""
}
name: "thenGraph_8"
output {
name: "self_2"
type {
}
}
}
type: GRAPH
}
attribute {
name: "else_branch"
g {
node {
input: "self"
output: "self_3"
name: "n0"
op_type: "Identity"
domain: ""
}
name: "elseGraph_8"
output {
name: "self_3"
type {
}
}
}
type: GRAPH
}
domain: ""
}
node {
input: "self_4"
output: "result"
name: "n6"
op_type: "LogSoftmax"
attribute {
name: "axis"
type: INT
ref_attr_name: "dim"
}
domain: ""
}
node {
input: "result"
output: "result_5"
name: "n7"
op_type: "Cast"
attribute {
name: "to"
type: INT
ref_attr_name: "dtype"
}
domain: ""
}
node {
input: "self_is_scalar"
output: "result_8"
name: "n8"
op_type: "If"
attribute {
name: "then_branch"
g {
node {
input: "result_5"
output: "result_6"
name: "n0"
op_type: "Squeeze"
domain: ""
}
name: "thenGraph_12"
output {
name: "result_6"
type {
}
}
}
type: GRAPH
}
attribute {
name: "else_branch"
g {
node {
input: "result_5"
output: "result_7"
name: "n0"
op_type: "Identity"
domain: ""
}
name: "elseGraph_12"
output {
name: "result_7"
type {
}
}
}
type: GRAPH
}
domain: ""
}
doc_string: "special_log_softmax(Tensor self, int dim, *, ScalarType? dtype=None) -> Tensor"
opset_import {
domain: ""
version: 18
}
domain: "pkg.onnxscript.torch_lib"
attribute_proto {
name: "dtype"
i: 1
type: INT
}
}
"""
ort_inputs = {'input_0': array([[[ 3.9902e+00, 3.7539e+00, -4.5352e+00, -2.1875e+00,
8.2812e+00],
[-8.7031e+00, -7.1992e+00, 7.0312e-01, 8.1797e+00,
4.9297e+00],
[ 7.6562e+00, 3.4023e+00, 8.7891e-03, -2.6367e-02,
3.4180e+00],
[-4.0352e+00, 9.2285e-01, 6.7773e+00, 7.2148e+00,
4.1836e+00],
[-2.8301e+00, -5.4766e+00, -2.5938e+00, 4.8789e+00,
-7.5859e+00],
[-7.2344e+00, 8.4141e+00, -2.5488e-01, -6.6367e+00,
7.5781e+00],
[-1.8369e+00, 2.3730e+00, -5.0000e+00, 4.0508e+00,
6.3828e+00],
[ 2.2070e+00, -7.0312e-02, 6.1602e+00, -6.4062e+00,
6.3633e+00],
[-2.6797e+00, 6.5742e+00, -6.0391e+00, -1.2832e+00,
4.5703e-01],
[-2.0469e+00, 8.0938e+00, -7.1016e+00, -2.2930e+00,
-7.3555e+00]],
[[-5.6680e+00, -4.6758e+00, 3.5156e-01, 1.3711e+00,
-8.8750e+00],
[ 7.1367e+00, -8.4375e+00, 7.5234e+00, 7.3672e+00,
-4.4297e+00],
[ 3.0156e+00, 1.1250e+00, 8.8125e+00, -3.3125e+00,
4.1406e+00],
[ 5.4492e-01, 1.2129e+00, 4.3750e+00, -3.7969e+00,
-5.5625e+00],
[ 5.9219e+00, -5.3281e+00, -6.4688e+00, -5.6797e+00,
6.7852e+00],
[ 4.2969e+00, -6.9766e+00, -6.1523e-02, -8.6484e+00,
2.3730e-01],
[-7.8203e+00, -7.2422e+00, 7.3750e+00, -2.1523e+00,
-8.3496e-01],
[ 2.8125e-01, 4.1309e-01, -4.5859e+00, -5.4297e+00,
5.0352e+00],
[ 6.3906e+00, -1.9336e+00, -8.1406e+00, -2.9961e+00,
7.6562e+00],
[-2.6289e+00, 8.6641e+00, 4.7969e+00, -5.6250e-01,
6.4844e+00]],
[[-3.6211e+00, 8.2812e+00, 4.0508e+00, -3.3574e+00,
6.7500e+00],
[ 3.5156e-02, 1.9072e+00, -4.5859e+00, -2.2676e+00,
-5.5117e+00],
[-1.3535e+00, -2.0215e+00, -5.5547e+00, -7.1875e+00,
1.2305e-01],
[ 7.5312e+00, 4.8594e+00, -1.1689e+00, 4.0430e+00,
6.0625e+00],
[ 5.0469e+00, -1.4150e+00, -2.4785e+00, 2.1094e+00,
-4.0508e+00],
[-3.5156e-02, -8.1641e+00, 7.9023e+00, 6.4414e+00,
-4.7461e+00],
[ 3.5684e+00, -6.9766e+00, -2.4258e+00, 5.7500e+00,
1.4941e+00],
[-4.2539e+00, 3.0762e-01, -4.3945e+00, 1.3975e+00,
7.3672e+00],
[ 6.1328e+00, 2.1270e+00, 7.4707e-01, 6.9883e+00,
-3.9297e+00],
[-3.0156e+00, -4.1133e+00, 5.0352e+00, 8.3672e+00,
-1.0547e-01]],
[[-3.3320e+00, -2.0117e+00, 2.3730e-01, 5.4414e+00,
4.8789e+00],
[ 2.9805e+00, 7.0039e+00, 8.4141e+00, 2.9004e+00,
5.6172e+00],
[ 4.4375e+00, 8.9141e+00, 3.0508e+00, 4.1484e+00,
2.0215e+00],
[-7.3398e+00, 1.9688e+00, 3.3750e+00, 3.3047e+00,
2.4785e+00],
[ 8.4141e+00, -3.4453e+00, -3.4883e+00, -4.5703e-01,
6.5391e+00],
[ 2.2676e+00, 3.1465e+00, -2.7695e+00, 4.0000e+00,
3.6035e+00],
[-3.3047e+00, -6.2148e+00, 5.9414e+00, 1.9688e+00,
7.1875e+00],
[ 2.0469e+00, -2.9961e+00, -3.1719e+00, 8.5625e+00,
7.2578e+00],
[ 3.2773e+00, 5.8984e+00, 5.3359e+00, 4.1406e+00,
2.0391e+00],
[ 6.0742e+00, 9.8438e-01, 7.2852e+00, 4.3594e+00,
-5.1250e+00]],
[[-4.0508e+00, 2.3477e+00, 4.4727e+00, 7.0859e+00,
3.3828e+00],
[ 3.3926e+00, 9.8438e-01, 5.3789e+00, 5.4844e+00,
-6.4160e-01],
[ 7.0312e-01, -2.8652e+00, -8.0156e+00, -8.9922e+00,
-3.3672e+00],
[-7.0742e+00, 4.9648e+00, -1.4941e+00, 8.7500e+00,
6.4609e+00],
[ 1.9688e+00, -5.6250e+00, 3.5332e+00, 5.5977e+00,
1.1338e+00],
[ 4.5703e-01, 6.0039e+00, -2.5938e+00, 5.8984e+00,
6.2305e+00],
[ 3.6484e+00, -6.8359e+00, -1.6611e+00, 3.8496e+00,
-7.2266e+00],
[ 1.6084e+00, -7.9297e+00, -5.4844e+00, -2.2676e+00,
-6.4258e+00],
[-6.5469e+00, 7.3477e+00, -1.2393e+00, 5.3516e+00,
-7.2695e+00],
[-3.7793e-01, 7.9609e+00, 4.4844e+00, -2.3984e+00,
-7.7422e+00]]], dtype=float16)}
# Set up the inference session
session_options = ort.SessionOptions()
session_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_DISABLE_ALL
onnx_model = onnx.ModelProto()
google.protobuf.text_format.Parse(onnx_model_text, onnx_model)
# Uncomment this line to save the model to a file for examination
# onnx.save_model(onnx_model, "test_output_match_opinfo__log_softmax_with_dtype_cpu_float16.onnx")
onnx.checker.check_model(onnx_model)
session = ort.InferenceSession(onnx_model.SerializeToString(), session_options, providers=("CPUExecutionProvider",))
# Run the model
for _ in range(N):
    ort_outputs = session.run(None, ort_inputs)
Full error stack
Environment
Describe the issue
ORT reports error:
E onnxruntime.capi.onnxruntime_pybind11_state.InvalidGraph: [ONNXRuntimeError] : 10 : INVALID_GRAPH : This is an invalid model. In Node, ("", ReduceMax, "", -1) : ("_inline_aten_special_softmaxself_5": tensor(float16),) -> ("_inline_SoftmaxX_ReduceMax",) , Error Unrecognized attribute: axes for operator ReduceMax
when the graph is
Softmax has two function bodies defined for different opset versions (13 and 18). Since ReduceMax has a new input axes in opset 18, I suspect ORT was not taking the correct function from Softmax's definition: https://github.com/onnx/onnx/blob/e2e97ccc36211e72c3607d46f71782542d1df5ee/onnx/defs/math/defs.cc#L1017-L1031
cc @gramalingam
To reproduce
@justinchuby
Urgency
No response
Platform
Linux
OS Version
N/A
ONNX Runtime Installation
Released Package
ONNX Runtime Version or Commit ID
1.15.0
ONNX Runtime API
Python
Architecture
X64
Execution Provider
Default CPU
Execution Provider Library Version
No response