Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TVM EP] Improved usability of TVM EP #10241

Merged
merged 6 commits into from
Jan 25, 2022
Merged
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,28 @@ def _rewrite_ld_preload_tensorrt(self, to_preload):
f.write(' import os\n')
f.write(' os.environ["ORT_TENSORRT_UNAVAILABLE"] = "1"\n')

def _rewrite_ld_preload_tvm(self):
with open('onnxruntime/capi/_ld_preload.py', 'a') as f:
f.write('import warnings\n\n')
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I suggest the following to make this part easier to read and to modify.

f.write(textwrap.dedent("""
"""))

f.write('try:\n')
f.write(' # This import is necessary in order to delegate the loading of libtvm.so to TVM.\n')
f.write(' import tvm\n')
f.write('except ImportError as e:\n')
f.write(' warnings.warn(\n')
f.write(' f"WARNING: Failed to import TVM, libtvm.so was not loaded. More details: {e}"\n')
f.write(' )\n\n')
f.write('try:\n')
f.write(' # Working between the C++ and Python parts in TVM EP is done using\n')
f.write(' # the PackedFunc and Registry classes. In order to use a Python function in C++ code,\n')
f.write(' # it must be registered in the global table of functions. Registration is carried out\n')
f.write(' # through JIT interface, so it is necessary to call special functions for registration.\n')
f.write(' # To do this, we need to make the following import.\n')
f.write(' import onnxruntime.providers.stvm\n')
f.write('except ImportError as e:\n')
f.write(' warnings.warn(\n')
f.write(' f"WARNING: Failed to register python functions for TVM EP. More details: {e}"\n')
f.write(' )\n')

def run(self):
if is_manylinux:
source = 'onnxruntime/capi/onnxruntime_pybind11_state.so'
Expand Down Expand Up @@ -207,6 +229,8 @@ def run(self):
self._rewrite_ld_preload(to_preload)
self._rewrite_ld_preload_cuda(to_preload_cuda)
self._rewrite_ld_preload_tensorrt(to_preload_tensorrt)
if package_name == 'onnxruntime-tvm':
self._rewrite_ld_preload_tvm()
_bdist_wheel.run(self)
if is_manylinux and not disable_auditwheel_repair:
file = glob(path.join(self.dist_dir, '*linux*.whl'))[0]
Expand Down