
Calling convert_to_onnx_and_check fails with "Error No Op registered for Range with domain_version of 10" #98

Closed
CPFLAME opened this issue Oct 10, 2022 · 3 comments · Fixed by #99


@CPFLAME
Contributor

CPFLAME commented Oct 10, 2022

When calling convert_to_onnx_and_check:

convert_to_onnx_and_check(
    t5_graph,
    external_data=False,
    opset=None,
    flow_weight_dir=None,
    onnx_model_path="./",
    dynamic_batch_size=False,
)

it raises:

onnxruntime.capi.onnxruntime_pybind11_state.InvalidGraph: [ONNXRuntimeError] : 10 : INVALID_GRAPH : Load model from ./model.onnx failed:This is an invalid model. In Node, ("model.t5_model.encoder.layers.0.self_attention-arange-18", Range, "", -1) : ("start85": tensor(int64),"limit80": tensor(int64),"delta81": tensor(int64),) -> ("model.t5_model.encoder.layers.0.self_attention-arange-18/out_0",) , Error No Op registered for Range with domain_version of 10

After searching online, NVIDIA/TensorRT#1658 suggests this is caused by the ONNX opset version.
I changed the call to:

convert_to_onnx_and_check(
    t5_graph,
    external_data=False,
    opset=11,
    flow_weight_dir=None,
    onnx_model_path="./",
    dynamic_batch_size=False,
)

which produces a new error:

Traceback (most recent call last):
  File "libai/onnx_export/t5_to_onnx.py", line 57, in <module>
    convert_to_onnx_and_check(t5_graph,
  File "/home/chengpeng/miniconda3/envs/libai/lib/python3.8/site-packages/oneflow_onnx-0.5.5-py3.8.egg/oneflow_onnx/oneflow2onnx/util.py", line 99, in convert_to_onnx_and_check
  File "/home/chengpeng/miniconda3/envs/libai/lib/python3.8/site-packages/oneflow_onnx-0.5.5-py3.8.egg/oneflow_onnx/oneflow2onnx/util.py", line 29, in run_onnx
  File "/home/chengpeng/miniconda3/envs/libai/lib/python3.8/site-packages/onnxruntime/capi/onnxruntime_inference_collection.py", line 347, in __init__
    self._create_inference_session(providers, provider_options, disabled_optimizers)
  File "/home/chengpeng/miniconda3/envs/libai/lib/python3.8/site-packages/onnxruntime/capi/onnxruntime_inference_collection.py", line 384, in _create_inference_session
    sess = C.InferenceSession(session_options, self._model_path, True, self._read_config_from_model)
onnxruntime.capi.onnxruntime_pybind11_state.InvalidGraph: [ONNXRuntimeError] : 10 : INVALID_GRAPH : Load model from ./model.onnx failed:This is an invalid model. Type Error: Type 'tensor(int64)' of input parameter (model.t5_model.encoder.layers.0.self_attention-scalar_add-25/out_0) of operator (Sum) in node (model.t5_model.encoder.layers.0.self_attention-add_n-39) is invalid.
@CPFLAME
Contributor Author

CPFLAME commented Oct 10, 2022

It currently looks like the cause is that the result of flow.where cannot be added to a flow.tensor. Minimal reproduction:

import oneflow as flow
from oneflow import nn
from oneflow_onnx.oneflow2onnx.util import export_onnx_model, convert_to_onnx_and_check


class DemoModel(nn.Module):
    def __init__(self, ):
        super().__init__()
        self.is_small = flow.ones(
            3, 3,
            sbp=flow.sbp.broadcast,
            placement=flow.placement("cuda", ranks=[0]),
            dtype=flow.bool
        )
        self.relative_postion_if_large = flow.ones(
            3, 3,
            sbp=flow.sbp.broadcast,
            placement=flow.placement("cuda", ranks=[0]),
            dtype=flow.int64
        )

    
    def forward(self, relative_position):
        relative_buckets = flow.ones(
            3, 3,
            sbp=flow.sbp.broadcast,
            placement=flow.placement("cuda", ranks=[0]),
            dtype=flow.int64
        )
        temp = flow.where(
            self.is_small,
            relative_position,
            self.relative_postion_if_large
        )

        relative_buckets = relative_buckets + temp

        return relative_buckets

class ModelGraph(nn.Graph):
    def __init__(self, eager_model):
        super().__init__()
        self.model = eager_model

    def build(self, x):
        return self.model(x)

model = DemoModel()
model.eval()

input_tensor = flow.ones(
    3, 3,
    sbp=flow.sbp.broadcast,
    placement=flow.placement("cuda", ranks=[0]),
    dtype=flow.int64
)
res = model.forward(input_tensor)
print("eager output:", res)

modelgraph = ModelGraph(model)
# Build the static graph model
modelgraph._compile(
    input_tensor
)

convert_to_onnx_and_check(
    modelgraph,
    external_data=False,
    opset=11,
    flow_weight_dir=None,
    onnx_model_path="./",
    dynamic_batch_size=False,
)

Error:

loaded library: /lib/libibverbs.so.1
eager output: tensor([[2, 2, 2],
        [2, 2, 2],
        [2, 2, 2]], placement=oneflow.placement(type="cuda", ranks=[0]), sbp=(oneflow.sbp.broadcast,),
       dtype=oneflow.int64)
Traceback (most recent call last):
  File "test_onnx.py", line 66, in <module>
    convert_to_onnx_and_check(
  File "/home/chengpeng/oneflow_convert/oneflow_onnx/oneflow2onnx/util.py", line 99, in convert_to_onnx_and_check
    ipt_dict, onnx_res = run_onnx(onnx_model_path, ["CPUExecutionProvider"], ort_optimize=ort_optimize)
  File "/home/chengpeng/oneflow_convert/oneflow_onnx/oneflow2onnx/util.py", line 29, in run_onnx
    sess = ort.InferenceSession(onnx_model_path, sess_options=ort_sess_opt, providers=providers)
  File "/home/chengpeng/miniconda3/envs/libai/lib/python3.8/site-packages/onnxruntime/capi/onnxruntime_inference_collection.py", line 347, in __init__
    self._create_inference_session(providers, provider_options, disabled_optimizers)
  File "/home/chengpeng/miniconda3/envs/libai/lib/python3.8/site-packages/onnxruntime/capi/onnxruntime_inference_collection.py", line 384, in _create_inference_session
    sess = C.InferenceSession(session_options, self._model_path, True, self._read_config_from_model)
onnxruntime.capi.onnxruntime_pybind11_state.InvalidGraph: [ONNXRuntimeError] : 10 : INVALID_GRAPH : Load model from ./model.onnx failed:This is an invalid model. Type Error: Type 'tensor(int64)' of input parameter (model-where-1/out_0) of operator (Sum) in node (model-add_n-4) is invalid.

@BBuf
Contributor

BBuf commented Oct 11, 2022

The issue above has been fixed, but there is now a bigger problem with non-deterministic input ordering: https://github.com/Oneflow-Inc/OneTeam/issues/1722

@BBuf
Contributor

BBuf commented Oct 13, 2022

All of the issues above have been resolved.

@BBuf BBuf closed this as completed Oct 13, 2022