diff --git a/.gitignore b/.gitignore
index cb8962e677..f8fe01ca38 100644
--- a/.gitignore
+++ b/.gitignore
@@ -33,4 +33,6 @@ __pycache__
 *.egg-info
 dist
 bdist
-py/trtorch/_version.py
\ No newline at end of file
+py/trtorch/_version.py
+py/wheelhouse
+py/.eggs
\ No newline at end of file
diff --git a/README.md b/README.md
index f06d9760a1..91b215f1c1 100644
--- a/README.md
+++ b/README.md
@@ -137,6 +137,16 @@ A tarball with the include files and library can then be found in bazel-bin
 bazel run //cpp/trtorchexec -- $(realpath )
 ```
 
+## Compiling the Python Package
+
+To compile the Python package for your local machine, run `python3 setup.py install` in the `//py` directory.
+To build wheel files for different Python versions, first build the Docker image from the Dockerfile in `//py`
+(e.g. `docker build -t build_trtorch_wheel .`), then run the following command:
+```
+docker run -it -v$(pwd)/..:/workspace/TRTorch build_trtorch_wheel /bin/bash /workspace/TRTorch/py/build_whl.sh
+```
+Python compilation expects that the tarball-based compilation strategy described above has been used.
+
 ## How do I add support for a new op...
 
 ### In TRTorch?
diff --git a/py/Dockerfile b/py/Dockerfile
new file mode 100644
index 0000000000..7e60e9b4f5
--- /dev/null
+++ b/py/Dockerfile
@@ -0,0 +1,14 @@
+FROM pytorch/manylinux-cuda102
+
+RUN yum install -y ninja-build
+
+RUN wget https://copr.fedorainfracloud.org/coprs/vbatts/bazel/repo/epel-7/vbatts-bazel-epel-7.repo \
+    && mv vbatts-bazel-epel-7.repo /etc/yum.repos.d/
+
+RUN yum install -y bazel3
+
+RUN mv /usr/bin/ninja-build /usr/bin/ninja
+
+RUN mkdir /workspace
+
+WORKDIR /workspace
\ No newline at end of file
diff --git a/py/LICENSE b/py/LICENSE
new file mode 120000
index 0000000000..ea5b60640b
--- /dev/null
+++ b/py/LICENSE
@@ -0,0 +1 @@
+../LICENSE
\ No newline at end of file
diff --git a/py/README.md b/py/README.md
new file mode 100644
index 0000000000..5f262c78a0
--- /dev/null
+++ b/py/README.md
@@ -0,0 +1,168 @@
+# trtorch
+
+> Ahead of Time (AOT) compiling for PyTorch JIT
+
+TRTorch is a compiler for PyTorch/TorchScript, targeting NVIDIA GPUs via NVIDIA's TensorRT Deep Learning Optimizer and Runtime. Unlike PyTorch's Just-In-Time (JIT) compiler, TRTorch is an Ahead-of-Time (AOT) compiler, meaning that before you deploy your TorchScript code, you go through an explicit compile step to convert a standard TorchScript program into a module targeting a TensorRT engine. TRTorch operates as a PyTorch extension and compiles modules that integrate into the JIT runtime seamlessly. After compilation, using the optimized graph should feel no different than running a TorchScript module. You also have access to TensorRT's suite of configurations at compile time, so you are able to specify operating precision (FP32/FP16/INT8) and other settings for your module.
+
+## Example Usage
+
+``` python
+import torch
+import torchvision
+import trtorch
+
+# Get a model
+model = torchvision.models.alexnet(pretrained=True).eval().cuda()
+
+# Create some example data
+data = torch.randn((1, 3, 224, 224)).to("cuda")
+
+# Trace the module with example data
+traced_model = torch.jit.trace(model, [data])
+
+# Compile the traced module
+compiled_trt_model = trtorch.compile(traced_model, {
+    "input_shape": [data.shape],
+    "op_precision": torch.half, # Run in FP16
+})
+
+results = compiled_trt_model(data.half())
+```
+
+## Installation
+
+```
+pip3 install trtorch
+```
+
+## Under the Hood
+
+When a traced module is provided to TRTorch, the compiler takes the internal representation and transforms it into one like this:
+
+```
+graph(%input.2 : Tensor):
+  %2 : Float(84, 10) = prim::Constant[value=<Tensor>]()
+  %3 : Float(120, 84) = prim::Constant[value=<Tensor>]()
+  %4 : Float(576, 120) = prim::Constant[value=<Tensor>]()
+  %5 : int = prim::Constant[value=-1]() # x.py:25:0
+  %6 : int[] = prim::Constant[value=annotate(List[int], [])]()
+  %7 : int[] = prim::Constant[value=[2, 2]]()
+  %8 : int[] = prim::Constant[value=[0, 0]]()
+  %9 : int[] = prim::Constant[value=[1, 1]]()
+  %10 : bool = prim::Constant[value=1]() # ~/.local/lib/python3.6/site-packages/torch/nn/modules/conv.py:346:0
+  %11 : int = prim::Constant[value=1]() # ~/.local/lib/python3.6/site-packages/torch/nn/functional.py:539:0
+  %12 : bool = prim::Constant[value=0]() # ~/.local/lib/python3.6/site-packages/torch/nn/functional.py:539:0
+  %self.classifer.fc3.bias : Float(10) = prim::Constant[value= 0.0464 0.0383 0.0678 0.0932 0.1045 -0.0805 -0.0435 -0.0818 0.0208 -0.0358 [ CUDAFloatType{10} ]]()
+  %self.classifer.fc2.bias : Float(84) = prim::Constant[value=<Tensor>]()
+  %self.classifer.fc1.bias : Float(120) = prim::Constant[value=<Tensor>]()
+  %self.feat.conv2.weight : Float(16, 6, 3, 3) = prim::Constant[value=<Tensor>]()
+  %self.feat.conv2.bias : Float(16) = prim::Constant[value=<Tensor>]()
+  %self.feat.conv1.weight : Float(6, 1, 3, 3) = prim::Constant[value=<Tensor>]()
+  %self.feat.conv1.bias : Float(6) = prim::Constant[value= 0.0530 -0.1691 0.2802 0.1502 0.1056 -0.1549 [ CUDAFloatType{6} ]]()
+  %input0.4 : Tensor = aten::_convolution(%input.2, %self.feat.conv1.weight, %self.feat.conv1.bias, %9, %8, %9, %12, %8, %11, %12, %12, %10) # ~/.local/lib/python3.6/site-packages/torch/nn/modules/conv.py:346:0
+  %input0.5 : Tensor = aten::relu(%input0.4) # ~/.local/lib/python3.6/site-packages/torch/nn/functional.py:1063:0
+  %input1.2 : Tensor = aten::max_pool2d(%input0.5, %7, %6, %8, %9, %12) # ~/.local/lib/python3.6/site-packages/torch/nn/functional.py:539:0
+  %input0.6 : Tensor = aten::_convolution(%input1.2, %self.feat.conv2.weight, %self.feat.conv2.bias, %9, %8, %9, %12, %8, %11, %12, %12, %10) # ~/.local/lib/python3.6/site-packages/torch/nn/modules/conv.py:346:0
+  %input2.1 : Tensor = aten::relu(%input0.6) # ~/.local/lib/python3.6/site-packages/torch/nn/functional.py:1063:0
+  %x.1 : Tensor = aten::max_pool2d(%input2.1, %7, %6, %8, %9, %12) # ~/.local/lib/python3.6/site-packages/torch/nn/functional.py:539:0
+  %input.1 : Tensor = aten::flatten(%x.1, %11, %5) # x.py:25:0
+  %27 : Tensor = aten::matmul(%input.1, %4)
+  %28 : Tensor = trt::const(%self.classifer.fc1.bias)
+  %29 : Tensor = aten::add_(%28, %27, %11)
+  %input0.2 : Tensor = aten::relu(%29) # ~/.local/lib/python3.6/site-packages/torch/nn/functional.py:1063:0
+  %31 : Tensor = aten::matmul(%input0.2, %3)
+  %32 : Tensor = trt::const(%self.classifer.fc2.bias)
+  %33 : Tensor = aten::add_(%32, %31, %11)
+  %input1.1 : Tensor = aten::relu(%33) # ~/.local/lib/python3.6/site-packages/torch/nn/functional.py:1063:0
+  %35 : Tensor = aten::matmul(%input1.1, %2)
+  %36 : Tensor = trt::const(%self.classifer.fc3.bias)
+  %37 : Tensor = aten::add_(%36, %35, %11)
+  return (%37)
+(CompileGraph)
+```
+
+The graph has now been transformed from a collection of modules (much like how your PyTorch modules are collections of modules, each managing its own parameters) into a single graph
+with the parameters inlined into the graph and all of the operations laid out. TRTorch has also executed a number of optimizations and mappings to make the graph easier to translate
+to TensorRT. From here the compiler can assemble the TensorRT engine by following the dataflow through the graph.
+
+When the graph construction phase is complete, TRTorch produces a serialized TensorRT engine. From here, depending on the API, this engine is either returned to the user or moves into the module
+construction phase, where TRTorch creates a JIT module to execute the TensorRT engine, which will be instantiated and managed by the TRTorch runtime.
+
+Here is the graph that you get back after compilation is complete:
+
+```
+graph(%self.1 : __torch__.___torch_mangle_10.LeNet_trt,
+      %2 : Tensor):
+  %1 : int = prim::Constant[value=94106001690080]()
+  %3 : Tensor = trt::execute_engine(%1, %2)
+  return (%3)
+(AddEngineToGraph)
+```
+
+You can see the call to execute the engine, which takes a constant (the ID of the engine, telling the JIT runtime how to find it) and the input tensor that will be fed to TensorRT.
+The engine represents exactly the same calculations as running the original PyTorch module, but optimized to run on your GPU.
+
+TRTorch converts from TorchScript by generating layers or subgraphs in correspondence with instructions seen in the graph. Converters are small modules of code used to map one specific
+operation to a layer or subgraph in TensorRT. Not all operations are supported, but if you need one that is missing, you can implement it yourself in C++.
+
+## Registering Custom Converters
+
+Operations are mapped to TensorRT through the use of modular converters, functions that take a node from the JIT graph and produce an equivalent layer or subgraph in TensorRT. TRTorch
+ships with a library of these converters stored in a registry, which are executed depending on the node being parsed. For instance, an `aten::relu(%input0.4)` instruction will trigger the
+ReLU converter to be run on it, producing an activation layer in the TensorRT graph. Since this library is not exhaustive, you may need to write your own converter to get TRTorch to support your module.
+
+Shipped with the TRTorch distribution are the internal core API headers. You can therefore access the converter registry and add a converter for the op you need.
+
+For example, if you try to compile a graph with a build of TRTorch that doesn't support the flatten operation (`aten::flatten`), you may see this error:
+
+```
+terminate called after throwing an instance of 'trtorch::Error'
+what():  [enforce fail at core/conversion/conversion.cpp:109] Expected converter to be true but got false
+Unable to convert node: %input.1 : Tensor = aten::flatten(%x.1, %11, %5) # x.py:25:0 (conversion.AddLayer)
+Schema: aten::flatten.using_ints(Tensor self, int start_dim=0, int end_dim=-1) -> (Tensor)
+Converter for aten::flatten requested, but no such converter was found.
+If you need a converter for this operator, you can try implementing one yourself
+or request a converter: https://www.github.com/NVIDIA/TRTorch/issues
+```
+
+We can register a converter for this operator in our application. All of the tools required to build a converter can be imported by including `trtorch/core/conversion/converters/converters.h`.
+We start by creating an instance of the self-registering class `trtorch::core::conversion::converters::RegisterNodeConversionPatterns()`, which will register converters in the global converter
+registry, associating a function schema like `aten::flatten.using_ints(Tensor self, int start_dim=0, int end_dim=-1) -> (Tensor)` with a lambda that takes the state of the conversion, the
+node/operation in question to convert, and all of the inputs to the node, and produces a new layer in the TensorRT network as a side effect. Arguments are passed as a vector of inspectable unions
+of TensorRT ITensors and Torch IValues in the order arguments are listed in the schema.
+
+Below is an implementation of an `aten::flatten` converter that we can use in our application. You have full access to the Torch and TensorRT libraries in the converter implementation. So, for example,
+we can quickly get the output size by just running the operation in PyTorch instead of implementing the full calculation ourselves, as we do below for this flatten converter.
+
+```c++
+#include "torch/script.h"
+#include "trtorch/trtorch.h"
+#include "trtorch/core/conversion/converters/converters.h"
+
+static auto flatten_converter = trtorch::core::conversion::converters::RegisterNodeConversionPatterns()
+    .pattern({
+        "aten::flatten.using_ints(Tensor self, int start_dim=0, int end_dim=-1) -> (Tensor)",
+        [](trtorch::core::conversion::ConversionCtx* ctx,
+           const torch::jit::Node* n,
+           trtorch::core::conversion::converters::args& args) -> bool {
+            auto in = args[0].ITensor();
+            auto start_dim = args[1].unwrapToInt();
+            auto end_dim = args[2].unwrapToInt();
+            auto in_shape = trtorch::core::util::toVec(in->getDimensions());
+            auto out_shape = torch::flatten(torch::rand(in_shape), start_dim, end_dim).sizes();
+
+            auto shuffle = ctx->net->addShuffle(*in);
+            shuffle->setReshapeDimensions(trtorch::core::util::toDims(out_shape));
+            shuffle->setName(trtorch::core::util::node_info(n).c_str());
+
+            auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], shuffle->getOutput(0));
+            return true;
+        }
+    });
+```
+
+To use this converter in Python, it is recommended to use PyTorch's [C++ / CUDA Extension](https://pytorch.org/tutorials/advanced/cpp_extension.html#custom-c-and-cuda-extensions) template to wrap
+your library of converters into a `.so` that you can load with `ctypes.CDLL()` in your Python application (see the sketch at the end of this README).
+
+You can find more information on all the details of writing converters in the contributors documentation ([Writing Converters](https://nvidia.github.io/TRTorch/contributors/writing_converters.html#writing-converters)). If you
+find yourself with a large library of converter implementations, do consider upstreaming them; PRs are welcome, and it would be great for the community to benefit as well.
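+
+As a rough sketch of that loading workflow (the library name `libflatten_converter.so`, the model file, and the input shape below are illustrative placeholders, not files shipped with TRTorch), loading the converter library before compiling could look like this:
+
+``` python
+import ctypes
+
+import torch
+import trtorch
+
+# Loading the shared library runs its static initializers, which is what
+# registers the custom converters with TRTorch's global converter registry.
+# The path/name is a placeholder for whatever your extension build produces.
+ctypes.CDLL("./libflatten_converter.so")
+
+# Placeholder TorchScript module that uses aten::flatten
+model = torch.jit.load("lenet_scripted.ts").eval().cuda()
+
+trt_model = trtorch.compile(model, {
+    "input_shape": [[1, 1, 32, 32]],
+    "op_precision": torch.float,
+})
+```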
diff --git a/py/build_whl.sh b/py/build_whl.sh
new file mode 100755
index 0000000000..87a58feb05
--- /dev/null
+++ b/py/build_whl.sh
@@ -0,0 +1,36 @@
+#!/bin/bash
+
+# Example usage: docker run -it -v$(pwd)/..:/workspace/TRTorch build_trtorch_wheel /bin/bash /workspace/TRTorch/py/build_whl.sh
+
+cd /workspace/TRTorch/py
+
+export CXX=g++
+
+build_py35() {
+    /opt/python/cp35-cp35m/bin/python -m pip install -r requirements.txt
+    /opt/python/cp35-cp35m/bin/python setup.py bdist_wheel
+    #auditwheel repair --plat manylinux2014_x86_64
+}
+
+build_py36() {
+    /opt/python/cp36-cp36m/bin/python -m pip install -r requirements.txt
+    /opt/python/cp36-cp36m/bin/python setup.py bdist_wheel
+    #auditwheel repair --plat manylinux2014_x86_64
+}
+
+build_py37() {
+    /opt/python/cp37-cp37m/bin/python -m pip install -r requirements.txt
+    /opt/python/cp37-cp37m/bin/python setup.py bdist_wheel
+    #auditwheel repair --plat manylinux2014_x86_64
+}
+
+build_py38() {
+    /opt/python/cp38-cp38/bin/python -m pip install -r requirements.txt
+    /opt/python/cp38-cp38/bin/python setup.py bdist_wheel
+    #auditwheel repair --plat manylinux2014_x86_64
+}
+
+build_py35
+build_py36
+build_py37
+build_py38
\ No newline at end of file
diff --git a/py/requirements.txt b/py/requirements.txt
index 130ade137d..85b81ceeb5 100644
--- a/py/requirements.txt
+++ b/py/requirements.txt
@@ -1 +1 @@
-torch==1.5.0
+torch==1.5.0
\ No newline at end of file
diff --git a/py/setup.py b/py/setup.py
index d264f8066c..e77a8377f4 100644
--- a/py/setup.py
+++ b/py/setup.py
@@ -7,6 +7,7 @@
 from setuptools.command.develop import develop
 from setuptools.command.install import install
 from distutils.cmd import Command
+from wheel.bdist_wheel import bdist_wheel
 from torch.utils import cpp_extension
 from shutil import copyfile, rmtree
@@ -17,13 +18,15 @@
 __version__ = '0.0.2'
 
-def build_libtrtorch_pre_cxx11_abi(develop=True):
+def build_libtrtorch_pre_cxx11_abi(develop=True, use_dist_dir=True):
     cmd = ["/usr/bin/bazel", "build"]
     cmd.append("//cpp/api/lib:libtrtorch.so")
     if develop:
         cmd.append("--compilation_mode=dbg")
     else:
         cmd.append("--compilation_mode=opt")
+    if use_dist_dir:
+        cmd.append("--distdir=third_party/dist_dir/x86_64-linux-gnu")
     cmd.append("--config=python")
 
     print("building libtrtorch")
@@ -41,12 +44,15 @@ def gen_version_file():
         print("creating version file")
         f.write("__version__ = \"" + __version__ + '\"')
 
-def copy_libtrtorch():
+def copy_libtrtorch(multilinux=False):
     if not os.path.exists(dir_path + '/trtorch/lib'):
         os.makedirs(dir_path + '/trtorch/lib')
 
     print("copying library into module")
-    copyfile(dir_path + "/../bazel-bin/cpp/api/lib/libtrtorch.so", dir_path + '/trtorch/lib/libtrtorch.so')
+    if multilinux:
+        copyfile(dir_path + "/build/libtrtorch_build/libtrtorch.so", dir_path + '/trtorch/lib/libtrtorch.so')
+    else:
+        copyfile(dir_path + "/../bazel-bin/cpp/api/lib/libtrtorch.so", dir_path + '/trtorch/lib/libtrtorch.so')
 
 class DevelopCommand(develop):
     description = "Builds the package and symlinks it into the PYTHONPATH"
@@ -79,6 +85,21 @@ def run(self):
         copy_libtrtorch()
         install.run(self)
 
+class BdistCommand(bdist_wheel):
+    description = "Builds the package"
+
+    def initialize_options(self):
+        bdist_wheel.initialize_options(self)
+
+    def finalize_options(self):
+        bdist_wheel.finalize_options(self)
+
+    def run(self):
+        build_libtrtorch_pre_cxx11_abi(develop=False)
+        gen_version_file()
+        copy_libtrtorch()
+        bdist_wheel.run(self)
+
 class CleanCommand(Command):
     """Custom clean command to tidy up the project root."""
     PY_CLEAN_FILES = ['./build', './dist', './trtorch/__pycache__', './trtorch/lib', './*.pyc', './*.tgz', './*.egg-info']
@@ -114,9 +135,11 @@ def run(self):
         ],
         include_dirs=[
             dir_path + "/../",
+            dir_path + "/../bazel-TRTorch/external/tensorrt/include",
         ],
         extra_compile_args=[
-            "-D_GLIBCXX_USE_CXX11_ABI=0"
+            "-D_GLIBCXX_USE_CXX11_ABI=0",
+            "-Wno-deprecated-declarations",
         ],
         extra_link_args=[
             "-D_GLIBCXX_USE_CXX11_ABI=0"
@@ -128,6 +151,9 @@ def run(self):
     )
 ]
 
+with open("README.md", "r") as fh:
+    long_description = fh.read()
+
 setup(
     name='trtorch',
     version=__version__,
@@ -135,28 +161,46 @@ def run(self):
     author_email='narens@nvidia.com',
     url='https://nvidia.github.io/TRTorch',
     description='A compiler backend for PyTorch JIT targeting NVIDIA GPUs',
-    long_description='',
+    long_description_content_type='text/markdown',
+    long_description=long_description,
     ext_modules=ext_modules,
-    install_requires=['pybind11>=2.4'],
-    setup_requires=['pybind11>=2.4'],
+    install_requires=[
+        'torch==1.5.0',
+    ],
+    setup_requires=[],
     cmdclass={
         'install': InstallCommand,
         'clean': CleanCommand,
         'develop': DevelopCommand,
-        'build_ext': cpp_extension.BuildExtension
+        'build_ext': cpp_extension.BuildExtension,
+        'bdist_wheel': BdistCommand,
     },
     zip_safe=False,
-    license="BSD-3",
+    license="BSD",
     packages=find_packages(),
-    classifiers=["Intended Audience :: Developers",
-                 "Intended Audience :: Science/Research",
-                 "Operating System :: POSIX :: Linux",
-                 "Programming Language :: C++",
-                 "Programming Language :: Python",
-                 "Programming Language :: Python :: Implementation :: CPython",
-                 "Topic :: Scientific/Engineering",
-                 "Topic :: Scientific/Engineering :: Artifical Intelligence",
-                 "Topic :: Software Development",
-                 "Topic :: Software Developement :: Libraries"],
-
+    platforms=["Linux"],
+    classifiers=[
+        "Development Status :: 3 - Alpha",
+        "Environment :: GPU :: NVIDIA CUDA",
+        "License :: OSI Approved :: BSD License",
+        "Intended Audience :: Developers",
+        "Intended Audience :: Science/Research",
+        "Operating System :: POSIX :: Linux",
+        "Programming Language :: C++",
+        "Programming Language :: Python",
+        "Programming Language :: Python :: Implementation :: CPython",
+        "Topic :: Scientific/Engineering",
+        "Topic :: Scientific/Engineering :: Artificial Intelligence",
+        "Topic :: Software Development",
+        "Topic :: Software Development :: Libraries"
+    ],
+    python_requires='>=3.6',
+    include_package_data=True,
+    package_data={
+        'trtorch': ['lib/*.so'],
+    },
+    exclude_package_data={
+        '': ['*.cpp', '*.h'],
+        'trtorch': ['csrc/*.cpp'],
+    }
 )