
[Bug] Check failed: type_code_ == kTVMObjectHandle (0 vs. 8) : expected Object but got int #14717

Open
hcms1994 opened this issue Apr 25, 2023 · 3 comments
Labels
needs-triage PRs or issues that need to be investigated by maintainers to find the right assignees to address it type: bug


@hcms1994

hcms1994 commented Apr 25, 2023

I encountered this error on this call:
mod, config = partition_for_tensorrt(mod, params)

error message:

```
Traceback (most recent call last):
  File "autoTVM_tune_relay_tensorrt.py", line 105, in <module>
    tune_and_evaluate()
  File "autoTVM_tune_relay_tensorrt.py", line 67, in tune_and_evaluate
    mod, config = partition_for_tensorrt(mod, params)
  File "/home/jiyingyu/TVM_work/code/tvm-0.11.0-release/python/tvm/ir/module.py", line 110, in __getitem__
    return _ffi_api.Module_LookupDef(self, var)
  File "/home/jiyingyu/TVM_work/code/tvm-0.11.0-release/python/tvm/_ffi/_ctypes/packed_func.py", line 237, in __call__
    raise get_last_ffi_error()
tvm._ffi.base.TVMError: Traceback (most recent call last):
  3: TVMFuncCall
  2: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<tvm::TypeData (tvm::IRModule, tvm::GlobalTypeVar)>::AssignTypedLambda<tvm::{lambda(tvm::IRModule, tvm::GlobalTypeVar)#6}>(tvm::{lambda(tvm::IRModule, tvm::GlobalTypeVar)#6}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(...)
  1: tvm::runtime::TVMMovableArgValueWithContext_::operator tvm::GlobalTypeVar() const
  0: tvm::GlobalTypeVar tvm::runtime::TVMPODValue_::AsObjectRef<tvm::GlobalTypeVar>() const
  File "/home/jiyingyu/TVM_work/code/tvm-0.11.0-release/include/tvm/runtime/packed_func.h", line 777
TVMError: In function ir.Module_LookupDef(0: IRModule, 1: GlobalTypeVar) -> relay.TypeData: error while converting argument 1: [10:47:20] /home/jiyingyu/TVM_work/code/tvm-0.11.0-release/include/tvm/runtime/packed_func.h:1886:

An error occurred during the execution of TVM.
For more information, please see: https://tvm.apache.org/docs/errors.html

Check failed: type_code_ == kTVMObjectHandle (0 vs. 8) : expected Object but got int
```

Do you have any idea why this error happens?

@hcms1994 hcms1994 added needs-triage PRs or issues that need to be investigated by maintainers to find the right assignees to address it type: bug labels Apr 25, 2023
@masahi
Member

masahi commented Apr 25, 2023

Please provide a full repro script.

@hcms1994
Author

hcms1994 commented Apr 25, 2023

Thank you for your reply. Since the model file is large, I uploaded it to Baidu Netdisk.
Baidu Netdisk link: https://pan.baidu.com/s/1Td-WjIY8wvjOeO2HoDTsAQ password: ekql

The full repro script (test.txt) is as follows:
```python
import os
import numpy as np
import tvm
from tvm import relay, autotvm
import tvm.relay.testing
from tvm.autotvm.tuner import XGBTuner, GATuner, RandomTuner, GridSearchTuner
import tvm.contrib.graph_executor as runtime
import onnx
from tvm.relay.op.contrib.tensorrt import partition_for_tensorrt


def get_tvm_cloud_model_network():
    input_name0 = 'x2paddle_input0'
    input_shape0 = (1, 1, 16, 80)
    input_name1 = 'x2paddle_input1'
    input_shape1 = (1, 6, 192, 192)
    onnx_model = onnx.load_model('shuziren_v1.onnx')
    shape_dict = {}
    shape_dict[input_name0] = input_shape0
    shape_dict[input_name1] = input_shape1
    mod, params = relay.frontend.from_onnx(onnx_model, shape=shape_dict)
    return mod, params, [input_shape0, input_shape1], {}


target = tvm.target.cuda(arch="sm_52")
dtype = "float32"


def tune_and_evaluate():
    print("Extract tasks...")
    mod, params, input_shape, out_shape = get_tvm_cloud_model_network()
    print("get network success...")

    mod, config = partition_for_tensorrt(mod, params)
    for i in range(1):
        print("Compile...")
        with tvm.transform.PassContext(opt_level=3):
            lib = relay.build_module.build(mod, target=target, params=params)

        lib.export_library("shuziren_v1_default_agx_cuda_tensorrt_3.so")

        print("load parameters...")
        dev = tvm.device(str(target), 0)
        print(dev)
        print("device is right")
        module = runtime.GraphModule(lib["default"](dev))
        print("module is right")

        data_tvm0 = tvm.nd.array((np.random.uniform(size=input_shape[0])).astype('float32'))
        module.set_input("x2paddle_input0", data_tvm0)

        data_tvm1 = tvm.nd.array((np.random.uniform(size=input_shape[1])).astype('float32'))
        module.set_input("x2paddle_input1", data_tvm1)

        print("Evaluate inference time cost...")
        ftimer = module.module.time_evaluator("run", dev, number=1, repeat=600)
        prof_res = np.array(ftimer().results) * 1000  # convert to millisecond
        print(
            "Mean inference time (std dev): %.2f ms (%.2f ms)"
            % (np.mean(prof_res), np.std(prof_res))
        )


tune_and_evaluate()
```

@digital-nomad-cheng
Contributor

The reason for this error is that the docs for partition_for_tensorrt are outdated: in this version of TVM, the function returns only the partitioned module, not a (mod, config) pair. The correct usage is:

```python
mod = partition_for_tensorrt(mod, params)
with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(mod, target=target, params=params)
```

As can be seen in the source:

mod = partition_for_tensorrt(mod, params)
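The failure mode can be illustrated without TVM at all. The following is a minimal pure-Python sketch, where FakeModule and partition are hypothetical stand-ins (not TVM APIs): tuple-unpacking a single module-like return value iterates it, which calls __getitem__ with integer indices 0, 1, ..., and the index type check then rejects the int, just as the TVM FFI rejects an int where a GlobalTypeVar object is expected.

```python
class FakeModule:
    """Stand-in for tvm.IRModule: indexing expects a name/var object, not an int."""

    def __getitem__(self, var):
        # Mimics the FFI type check that produces
        # "Check failed: type_code_ == kTVMObjectHandle: expected Object but got int".
        if isinstance(var, int):
            raise TypeError("expected Object but got int")
        return f"definition of {var}"


def partition(mod, params):
    # Newer-style API: returns only the partitioned module, no config.
    return mod


mod = FakeModule()
try:
    # Old-style call site: tuple unpacking calls iter(mod); since FakeModule
    # defines only __getitem__, Python falls back to the sequence protocol
    # and calls mod[0], mod[1], ... with integer indices.
    new_mod, config = partition(mod, {})
except TypeError as e:
    print(e)  # expected Object but got int
```

This matches the Python half of the traceback above: unpacking invokes IRModule.__getitem__, which forwards the integer index to ir.Module_LookupDef, where the object-vs-int check fires.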
