From 313e26d056334dcde6d85177803e5d8f48d741bd Mon Sep 17 00:00:00 2001 From: KJlaccHoeUM9l Date: Thu, 20 Jan 2022 17:04:02 +0300 Subject: [PATCH] add conditional _ld_preload.py file extension for TVM EP --- setup.py | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/setup.py b/setup.py index 2a1ec7f25d9b5..80af93760ecad 100644 --- a/setup.py +++ b/setup.py @@ -145,6 +145,28 @@ def _rewrite_ld_preload_tensorrt(self, to_preload): f.write(' import os\n') f.write(' os.environ["ORT_TENSORRT_UNAVAILABLE"] = "1"\n') + def _rewrite_ld_preload_tvm(self): + with open('onnxruntime/capi/_ld_preload.py', 'a') as f: + f.write('import warnings\n\n') + f.write('try:\n') + f.write(' # This import is necessary in order to delegate the loading of libtvm.so to TVM.\n') + f.write(' import tvm\n') + f.write('except ImportError as e:\n') + f.write(' warnings.warn(\n') + f.write(' f"WARNING: Failed to import TVM, libtvm.so was not loaded. More details: {e}"\n') + f.write(' )\n\n') + f.write('try:\n') + f.write(' # Working between the C++ and Python parts in TVM EP is done using\n') + f.write(' # the PackedFunc and Registry classes. In order to use a Python function in C++ code,\n') + f.write(' # it must be registered in the global table of functions. Registration is carried out\n') + f.write(' # through JIT interface, so it is necessary to call special functions for registration.\n') + f.write(' # To do this, we need to make the following import.\n') + f.write(' import onnxruntime.providers.stvm\n') + f.write('except ImportError as e:\n') + f.write(' warnings.warn(\n') + f.write(' f"WARNING: Failed to register python functions for TVM EP. More details: {e}"\n') + f.write(' )\n') + def run(self): if is_manylinux: source = 'onnxruntime/capi/onnxruntime_pybind11_state.so' @@ -207,6 +229,8 @@ def run(self): self._rewrite_ld_preload(to_preload) self._rewrite_ld_preload_cuda(to_preload_cuda) self._rewrite_ld_preload_tensorrt(to_preload_tensorrt) + if package_name == 'onnxruntime-tvm': + self._rewrite_ld_preload_tvm() _bdist_wheel.run(self) if is_manylinux and not disable_auditwheel_repair: file = glob(path.join(self.dist_dir, '*linux*.whl'))[0]