Merge pull request #1 from natke/custom-ops-docs
Stage PR microsoft#7636
---
title: Export PyTorch model
nav_exclude: true
---

# Convert and Run Inference on a PyTorch Model with Custom Ops

## Export PyTorch model with custom ONNX operators
{: .no_toc }

With the [onnxruntime_customops](https://github.com/microsoft/onnxruntime-extensions) package, a PyTorch model containing operations that cannot be converted into standard ONNX operators can still be exported, and the converted ONNX model can still be run with ONNX Runtime plus the onnxruntime_customops package. This tutorial shows how it works.
## Converting

Suppose there is a model that cannot be converted because there is no matrix inverse operation in the ONNX standard opset, and suppose the model is defined as follows.
```python
import torch


class CustomInverse(torch.nn.Module):
    def forward(self, x):
        return torch.inverse(x) + x
```
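Without the custom op registration described below, exporting this model is expected to fail, since `torch.inverse` has no mapping to a standard ONNX operator. A minimal sketch of the failure (the exact exception type and message vary across PyTorch versions):

```python
import io

import torch

try:
    # Expected to raise: the exporter has no symbolic function for torch.inverse.
    torch.onnx.export(CustomInverse(), (torch.randn(3, 3),), io.BytesIO(), opset_version=12)
except Exception as e:
    print("Export failed as expected:", e)
```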
To export this model to ONNX format, we need to register a custom op handler for the `torch.onnx` exporter.
```python
from torch.onnx import register_custom_op_symbolic


def my_inverse(g, self):
    return g.op("ai.onnx.contrib::Inverse", self)


# register_custom_op_symbolic('<namespace>::inverse', my_inverse, <opset_version>)
register_custom_op_symbolic('::inverse', my_inverse, 1)
```
`<namespace>` is part of the torch operator name. For standard torch operators, the namespace can be omitted, as in `'::inverse'` above.

Then, invoke the exporter:
```python
import io

import onnx

x0 = torch.randn(3, 3)

# Export the model to ONNX.
f = io.BytesIO()
t_model = CustomInverse()
torch.onnx.export(t_model, (x0,), f, opset_version=12)
onnx_model = onnx.load(io.BytesIO(f.getvalue()))
```
Now we have an ONNX model in memory, and it can be saved to a disk file with `onnx.save_model(onnx_model, <file_path>)`.
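For example, a minimal sketch (the file name here is arbitrary), which also inspects the opset imports; the exported graph is expected to reference the custom domain alongside the default ONNX domain:

```python
# Persist the exported model to disk; any path works.
onnx.save_model(onnx_model, "custom_inverse.onnx")

# The custom domain (ai.onnx.contrib) should appear in the opset imports.
print([(op.domain, op.version) for op in onnx_model.opset_import])
```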
## Inference

This converted model cannot be run directly with ONNX Runtime, because of the custom operator, but it can easily be run with onnxruntime_customops.
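To see the failure mode, a minimal sketch using a plain `onnxruntime` session (the exact error message depends on the ONNX Runtime version):

```python
import onnxruntime

try:
    # Expected to fail: stock ONNX Runtime has no kernel for ai.onnx.contrib::Inverse.
    onnxruntime.InferenceSession(f.getvalue())
except Exception as e:
    print("Plain ONNX Runtime rejects the model:", e)
```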
First, let's define a PyOp function to interpret the custom op node in the ONNX model.
```python
import numpy

from onnxruntime_customops import onnx_op, PyOp


@onnx_op(op_type="Inverse")
def inverse(x):
    # The user's custom op implementation goes here:
    return numpy.linalg.inv(x)
```
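Since the registered function is plain Python, it can be sanity-checked directly, before any ONNX Runtime session is involved. A minimal check (the test matrix is arbitrary):

```python
import numpy

a = numpy.array([[4.0, 7.0], [2.0, 6.0]], dtype=numpy.float32)
# inverse(a) @ a should be approximately the identity matrix.
numpy.testing.assert_allclose(inverse(a) @ a, numpy.eye(2), atol=1e-5)
```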
* **ONNX Inference**
```python
from onnxruntime_customops import PyOrtFunction

onnx_fn = PyOrtFunction.from_model(onnx_model)
y = onnx_fn(x0.numpy())
print(y)
```
[[-3.081008   0.20269153  0.42009977]
 [-3.3962293  2.5986686   2.4447646 ]
 [ 0.7805753 -0.20394287 -2.7528977 ]]
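`PyOrtFunction` is a convenience wrapper around an ONNX Runtime session. As an alternative, a sketch of running the same model through a regular `onnxruntime.InferenceSession`, registering the extensions' native library on the session options; this assumes the package exposes `get_library_path()`, as current onnxruntime-extensions releases do:

```python
import onnxruntime
from onnxruntime_customops import get_library_path  # assumed API; the package is named onnxruntime_extensions in newer releases

so = onnxruntime.SessionOptions()
so.register_custom_ops_library(get_library_path())
sess = onnxruntime.InferenceSession(f.getvalue(), so)
y = sess.run(None, {sess.get_inputs()[0].name: x0.numpy()})[0]
```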
* **Compare the result with Pytorch**

This compares the outputs of the PyTorch model with the ONNX Runtime outputs, validating both the operator export and the custom op implementation.
```python
t_y = t_model(x0)
numpy.testing.assert_almost_equal(t_y.numpy(), y, decimal=5)
```
## Implement the custom op in C++ (optional)

To make the ONNX model with the custom op run in all the other languages supported by ONNX Runtime, and to be independent of Python, a C++ implementation is needed; see [inverse.hpp](https://github.com/microsoft/onnxruntime-extensions/blob/main/operators/math/inverse.hpp) for an example of how to do that.
```python
from onnxruntime_customops import enable_custom_op

# Disable the Python PyOp function and run with the C++ implementation instead.
enable_custom_op(False)
y = onnx_fn(x0.numpy())
print(y)
```
[[-3.081008   0.20269153  0.42009977]
 [-3.3962293  2.5986686   2.4447646 ]
 [ 0.7805753 -0.20394287 -2.7528977 ]]