Skip to content

Commit

Permalink
add conditional _ld_preload.py file extension for TVM EP
Browse files Browse the repository at this point in the history
  • Loading branch information
KJlaccHoeUM9l committed Jan 20, 2022
1 parent 89e1769 commit 313e26d
Showing 1 changed file with 24 additions and 0 deletions.
24 changes: 24 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,28 @@ def _rewrite_ld_preload_tensorrt(self, to_preload):
f.write(' import os\n')
f.write(' os.environ["ORT_TENSORRT_UNAVAILABLE"] = "1"\n')

def _rewrite_ld_preload_tvm(self):
with open('onnxruntime/capi/_ld_preload.py', 'a') as f:
f.write('import warnings\n\n')
f.write('try:\n')
f.write(' # This import is necessary in order to delegate the loading of libtvm.so to TVM.\n')
f.write(' import tvm\n')
f.write('except ImportError as e:\n')
f.write(' warnings.warn(\n')
f.write(' f"WARNING: Failed to import TVM, libtvm.so was not loaded. More details: {e}"\n')
f.write(' )\n\n')
f.write('try:\n')
f.write(' # Working between the C++ and Python parts in TVM EP is done using\n')
f.write(' # the PackedFunc and Registry classes. In order to use a Python function in C++ code,\n')
f.write(' # it must be registered in the global table of functions. Registration is carried out\n')
f.write(' # through JIT interface, so it is necessary to call special functions for registration.\n')
f.write(' # To do this, we need to make the following import.\n')
f.write(' import onnxruntime.providers.stvm\n')
f.write('except ImportError as e:\n')
f.write(' warnings.warn(\n')
f.write(' f"WARNING: Failed to register python functions for TVM EP. More details: {e}"\n')
f.write(' )\n')

def run(self):
if is_manylinux:
source = 'onnxruntime/capi/onnxruntime_pybind11_state.so'
Expand Down Expand Up @@ -207,6 +229,8 @@ def run(self):
self._rewrite_ld_preload(to_preload)
self._rewrite_ld_preload_cuda(to_preload_cuda)
self._rewrite_ld_preload_tensorrt(to_preload_tensorrt)
if package_name == 'onnxruntime-tvm':
self._rewrite_ld_preload_tvm()
_bdist_wheel.run(self)
if is_manylinux and not disable_auditwheel_repair:
file = glob(path.join(self.dist_dir, '*linux*.whl'))[0]
Expand Down

0 comments on commit 313e26d

Please sign in to comment.