Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TVM EP] Improved usability of TVM EP #10241

Merged
merged 6 commits into from
Jan 25, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import platform
import subprocess
import sys
import textwrap
import datetime

from pathlib import Path
Expand Down Expand Up @@ -145,6 +146,33 @@ def _rewrite_ld_preload_tensorrt(self, to_preload):
f.write(' import os\n')
f.write(' os.environ["ORT_TENSORRT_UNAVAILABLE"] = "1"\n')

def _rewrite_ld_preload_tvm(self):
with open('onnxruntime/capi/_ld_preload.py', 'a') as f:
f.write(textwrap.dedent(
"""
import warnings

try:
# This import is necessary in order to delegate the loading of libtvm.so to TVM.
import tvm
except ImportError as e:
warnings.warn(
f"WARNING: Failed to import TVM, libtvm.so was not loaded. More details: {e}"
)
try:
# Working between the C++ and Python parts in TVM EP is done using the PackedFunc and
# Registry classes. In order to use a Python function in C++ code, it must be registered in
# the global table of functions. Registration is carried out through the JIT interface,
# so it is necessary to call special functions for registration.
# To do this, we need to make the following import.
import onnxruntime.providers.stvm
except ImportError as e:
warnings.warn(
f"WARNING: Failed to register python functions to work with TVM EP. More details: {e}"
)
"""
))

def run(self):
if is_manylinux:
source = 'onnxruntime/capi/onnxruntime_pybind11_state.so'
Expand Down Expand Up @@ -207,6 +235,8 @@ def run(self):
self._rewrite_ld_preload(to_preload)
self._rewrite_ld_preload_cuda(to_preload_cuda)
self._rewrite_ld_preload_tensorrt(to_preload_tensorrt)
if package_name == 'onnxruntime-tvm':
self._rewrite_ld_preload_tvm()
_bdist_wheel.run(self)
if is_manylinux and not disable_auditwheel_repair:
file = glob(path.join(self.dist_dir, '*linux*.whl'))[0]
Expand Down