
[TVM EP] Improved usability of TVM EP #10241

Merged: 6 commits into microsoft:master on Jan 25, 2022

Conversation

KJlaccHoeUM9l
Contributor

Interaction between the C++ and Python parts of TVM EP is handled through the PackedFunc and Registry classes. To use a Python function from C++ code, it must be registered in the global function table.
Registration happens through the JIT interface, so special registration functions must be called. To do this, the following import is needed:

import onnxruntime.providers.stvm # necessary to register tvm_onnx_import_and_compile and others

To avoid writing this line at the beginning of every script where TVM EP is required, it was moved to __init__.py.
Thus, only one import is needed:

import onnxruntime
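
A minimal sketch of the idea (not the exact diff from this PR): in __init__.py the provider import is wrapped in a try/except, so builds without TVM EP are unaffected.

# Hedged sketch of the __init__.py addition; the real change may differ in details.
try:
    import onnxruntime.providers.stvm  # registers tvm_onnx_import_and_compile and others
except ImportError:
    pass  # TVM EP is not part of this build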

@@ -44,6 +44,15 @@
except ImportError:
    pass

try:
@xadupre
Member

I would prefer a mechanism that is only available when onnxruntime is compiled for this EP, not something that is always in the code regardless of the compilation options.

@KJlaccHoeUM9l
Contributor Author

Hello @xadupre!
Could you take a look at the changes? Perhaps you also have some ideas about this?

@xadupre
Member

It looks ok. At least users know this code is only needed for this provider. I had in mind something that does not add extra code when onnxruntime is compiled without this provider, such as an import made from C++ surrounded by #ifdef ... STVM ... #endif.

@KJlaccHoeUM9l
Contributor Author

Hello @xadupre!
Most of the model processing in onnxruntime is done in native C++. However, not all of the steps required for TVM EP to work can be done on the C++ side. Some functionality exists only in Python, so the EP implementation contains bindings that link the required functionality between C++ and Python. These bindings are in this file:

  1. tvm_run_with_benchmark
  2. tvm_run
  3. tvm_onnx_import_and_compile

The first two functions can be removed; we already have a corresponding task for this in Jira. However, the third function (tvm_onnx_import_and_compile) internally uses TVM functions that are only available in Python and cannot easily be ported to the C++ side.
Thus, at this stage a bridge between Python and C++ is required for TVM EP to work. This interaction is implemented with the PackedFunc and Registry classes. To use a Python function from C++ code, it must be registered in the global function table. Registration happens through the JIT interface, so special registration calls are needed:

@tvm.register_func("tvm_onnx_import_and_compile")
def onnx_compile(...):
    ...

This registration must happen at runtime, so we cannot move this part to the compilation stage, where #ifdef could be used.
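
To illustrate the mechanism, here is a minimal sketch (with a hypothetical function name, not code from this PR): a function registered from Python via tvm.register_func lands in TVM's global registry, which is the same table the C++ side queries via tvm::runtime::Registry::Get.

# Minimal sketch of the PackedFunc registration mechanism; "demo_add_one" is a
# hypothetical name used only for illustration.
import tvm

@tvm.register_func("demo_add_one")
def demo_add_one(x):
    return x + 1

# The C++ side performs the equivalent of this lookup; if registration never
# ran, the lookup yields a null pointer and the EP fails.
f = tvm.get_global_func("demo_add_one")
assert f(41) == 42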

Unfortunately, at the moment we are aware of only a few alternatives that could solve this problem:

  • Leave everything as it is, so that users trigger the registration themselves by writing import onnxruntime.providers.stvm -- this is inconvenient for the end user, because it is easy to forget. In that case, an attempt to use the function fails with a null_ptr error. To explain why this happened, the message in your PR could be extended to tell the user the exact line to insert into their Python script;
  • Rewrite the functionality we need in C++ -- a long and difficult path;
  • Move the tvm_onnx_import_and_compile function into TVM so that it is registered when libtvm.so is loaded into Python -- in our opinion this is not quite the right way, because this functionality is needed for onnxruntime to work, not for TVM;
  • Dynamically extend the __init__.py file when the onnxruntime project is built with the TVM EP flag set -- too dirty a solution.

In our opinion, the solution proposed in this PR is the best of these alternatives. It requires no action from the user and no deep rework of the code in ORT or TVM.

@tmoreau89, @jwfromm what do you think about this? Do you have any comments or suggestions?

xadupre previously approved these changes Jan 18, 2022
@xadupre
Member

xadupre commented Jan 18, 2022

Thanks for the clarifications.

@KJlaccHoeUM9l
Contributor Author

@xadupre, while working on another task I seem to have found a better way to solve this.
Conditional code insertion for TVM EP builds can be done through the _ld_preload.py file when the wheel package is built via setup.py.
To do this, we need to insert the changes from this PR into the _ld_preload.py file, similar to how it was done for TensorRT.
We tested this assumption directly (without changing setup.py) and it works. Now we need to modify setup.py and check that everything works correctly.

@xadupre
Member

xadupre commented Jan 18, 2022

Will it work for Windows or macOS?

@gramalingam
Contributor

Hi, I don't fully understand the details here, but I assume this supports initiating execution of onnxruntime from all languages supported by ORT, not just Python, right?

@KJlaccHoeUM9l
Contributor Author

Will it work for Windows or macOS?

Based on this file, this import will be called regardless of the platform.

@KJlaccHoeUM9l
Contributor Author

Hi, I don't fully understand the details here, but I assume this supports initiating execution of onnxruntime from all languages supported by ORT, not just Python, right?

Hello @gramalingam!
At the moment all development targets Linux and Python, because the EP is currently in alpha preview. Some other languages (such as Java) have small infrastructural declarations, but they have not been tested in any way.
Therefore, other languages most likely will not work at this stage.

@tmoreau89
Contributor

As @KJlaccHoeUM9l said, much of the validation work we've done has been on Linux using Python. I suggest that expanding the scope of supported OSes and languages/environments be tackled in follow-up work/PRs.

@KJlaccHoeUM9l
Contributor Author

Hello @xadupre! Could you take a look at the changes?
We checked the solution I mentioned earlier, and everything works correctly. The advantage of this approach is that if the TVM EP flag is not set, none of this code appears in the package.

setup.py Outdated
@@ -145,6 +145,28 @@ def _rewrite_ld_preload_tensorrt(self, to_preload):
f.write(' import os\n')
f.write(' os.environ["ORT_TENSORRT_UNAVAILABLE"] = "1"\n')

def _rewrite_ld_preload_tvm(self):
    with open('onnxruntime/capi/_ld_preload.py', 'a') as f:
        f.write('import warnings\n\n')
@xadupre
Member

I suggest the following to make this part easier to read and to modify.

f.write(textwrap.dedent("""
"""))

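For illustration, a hedged sketch of how this suggestion could be applied to _rewrite_ld_preload_tvm (assuming textwrap is imported at the top of setup.py; the code actually merged in this PR may differ):

def _rewrite_ld_preload_tvm(self):
    with open('onnxruntime/capi/_ld_preload.py', 'a') as f:
        f.write(textwrap.dedent("""
            import warnings

            try:
                # Registers tvm_onnx_import_and_compile and the other TVM EP hooks.
                import onnxruntime.providers.stvm  # noqa: F401
            except ImportError as e:
                warnings.warn(f"Failed to register TVM EP Python hooks: {e}")
            """))
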
xadupre previously approved these changes Jan 20, 2022
@KJlaccHoeUM9l
Contributor Author

@xadupre, your comment has been taken into account. Could you please re-approve the PR?

KJlaccHoeUM9l changed the title from "[TVM EP] WIP: improved usability of TVM EP" to "[TVM EP] Improved usability of TVM EP" on Jan 21, 2022
@xadupre
Member

xadupre commented Jan 21, 2022

/azp run Linux CPU CI Pipeline, Linux CPU Minimal Build E2E CI Pipeline, Linux GPU CI Pipeline, Linux GPU TensorRT CI Pipeline, Linux Nuphar CI Pipeline, Linux OpenVINO CI Pipeline, MacOS CI Pipeline, ONNX Runtime Web CI Pipeline, onnxruntime-binary-size-checks-ci-pipeline

@xadupre
Member

xadupre commented Jan 21, 2022

/azp run Windows CPU CI Pipeline, Windows GPU CI Pipeline, Windows GPU TensorRT CI Pipeline, Windows WebAssembly CI Pipeline, orttraining-amd-gpu-ci-pipeline, orttraining-linux-ci-pipeline, orttraining-linux-gpu-ci-pipeline, orttraining-ortmodule-distributed, onnxruntime-python-checks-ci-pipeline

@azure-pipelines

Azure Pipelines successfully started running 9 pipeline(s).

@azure-pipelines

Azure Pipelines successfully started running 8 pipeline(s).

xadupre merged commit a0fe4a7 into microsoft:master on Jan 25, 2022
petersalas pushed a commit to octoml/onnxruntime that referenced this pull request Nov 7, 2022
* improved usability of TVM EP
* moved technical import under a condition related to TVM EP only
* Revert "moved technical import under a condition related to TVM EP only"
* add conditional _ld_preload.py file extension for TVM EP
* improve readability of inserted code

(cherry picked from commit a0fe4a7)