chore: Add documentation for dynamo.compile backend #2389

Merged · 4 commits · Oct 26, 2023

Changes from all commits
2 changes: 2 additions & 0 deletions docsrc/index.rst
@@ -41,6 +41,7 @@ User Guide
* :ref:`creating_a_ts_mod`
* :ref:`getting_started_with_fx`
* :ref:`torch_compile`
* :ref:`dynamo_export`
* :ref:`ptq`
* :ref:`runtime`
* :ref:`saving_models`
@@ -56,6 +57,7 @@ User Guide
user_guide/creating_torchscript_module_in_python
user_guide/getting_started_with_fx_path
user_guide/torch_compile
user_guide/dynamo_export
user_guide/ptq
user_guide/runtime
user_guide/saving_models
82 changes: 82 additions & 0 deletions docsrc/user_guide/dynamo_export.rst
@@ -0,0 +1,82 @@
.. _dynamo_export:

Torch-TensorRT Dynamo Backend
=============================================
.. currentmodule:: torch_tensorrt.dynamo

.. automodule:: torch_tensorrt.dynamo
    :members:
    :undoc-members:
    :show-inheritance:

This guide presents the Torch-TensorRT dynamo backend, which optimizes PyTorch models
using TensorRT in an ahead-of-time fashion.

Using the Dynamo backend
----------------------------------------
PyTorch 2.1 introduced the ``torch.export`` API, which
can export graphs from PyTorch programs into ``ExportedProgram`` objects. The Torch-TensorRT dynamo
backend compiles these ``ExportedProgram`` objects and optimizes them using TensorRT. Here is a simple
usage of the dynamo backend:

.. code-block:: python

    import torch
    import torch_tensorrt

    model = MyModel().eval().cuda()  # MyModel is a user-defined torch.nn.Module
    inputs = [torch.randn((1, 3, 224, 224), dtype=torch.float32).cuda()]
    exp_program = torch.export.export(model, tuple(inputs))
    trt_gm = torch_tensorrt.dynamo.compile(exp_program, inputs)  # Output is a torch.fx.GraphModule
    trt_gm(*inputs)

.. note:: ``torch_tensorrt.dynamo.compile`` is the main API for users to interact with the Torch-TensorRT dynamo backend. The input model should be an ``ExportedProgram`` (ideally the output of ``torch.export.export`` or ``torch_tensorrt.dynamo.trace``, discussed in the section below), and the output is a ``torch.fx.GraphModule`` object.

Customizable Settings
----------------------

There are many options for users to customize their settings when optimizing with TensorRT.
Some of the frequently used options are as follows:

* ``inputs`` - For static shapes, this can be a list of torch tensors or ``torch_tensorrt.Input`` objects. For dynamic shapes, this should be a list of ``torch_tensorrt.Input`` objects.
* ``enabled_precisions`` - Set of precisions that TensorRT builder can use during optimization.
* ``truncate_long_and_double`` - Truncates ``long`` and ``double`` values to ``int`` and ``float`` respectively.
* ``torch_executed_ops`` - Operators which are forced to be executed by Torch.
* ``min_block_size`` - Minimum number of consecutive operators required to be executed as a TensorRT segment.

The complete list of options can be found `here <https://github.com/pytorch/TensorRT/blob/123a486d6644a5bbeeec33e2f32257349acc0b8f/py/torch_tensorrt/dynamo/compile.py#L51-L77>`_.
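
As a minimal sketch of how these options compose (the specific values below are illustrative assumptions, not recommendations, and ``MyModel`` remains a placeholder module), they are passed as keyword arguments to ``torch_tensorrt.dynamo.compile``:

.. code-block:: python

    import torch
    import torch_tensorrt

    model = MyModel().eval().cuda()  # MyModel is a placeholder user-defined module
    inputs = [torch.randn((1, 3, 224, 224), dtype=torch.float32).cuda()]
    exp_program = torch.export.export(model, tuple(inputs))

    trt_gm = torch_tensorrt.dynamo.compile(
        exp_program,
        inputs,
        # Let TensorRT select FP16 kernels in addition to FP32 ones.
        enabled_precisions={torch.float32, torch.float16},
        # Cast long/double values to int/float for TensorRT compatibility.
        truncate_long_and_double=True,
        # Require at least two consecutive convertible ops per TensorRT segment;
        # torch_executed_ops could likewise be set to force ops to run in Torch.
        min_block_size=2,
    )

Any option left unset falls back to its default; see the linked list for the full set.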

.. note:: We do not support INT precision currently in dynamo. Support for this currently exists in
   our Torchscript IR, and we plan to implement similar support for dynamo in our next release.

Under the hood
--------------

Under the hood, ``torch_tensorrt.dynamo.compile`` performs the following on the graph.

* Lowering - Applies lowering passes to add/remove operators for optimal conversion.
* Partitioning - Partitions the graph into PyTorch and TensorRT segments based on the ``min_block_size`` and ``torch_executed_ops`` fields.
* Conversion - PyTorch ops get converted into TensorRT ops in this phase.
* Optimization - Post conversion, we build the TensorRT engine and embed it inside the PyTorch graph.
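
Because the output is a standard ``torch.fx.GraphModule``, the effect of partitioning can be inspected directly. A rough illustration (submodule and node names are internal details and may vary across versions):

.. code-block:: python

    # trt_gm is the torch.fx.GraphModule returned by torch_tensorrt.dynamo.compile.
    # TensorRT-compiled segments appear as submodules invoked from the graph,
    # while ops excluded via min_block_size / torch_executed_ops remain as
    # ordinary PyTorch nodes.
    print(trt_gm.graph)

    for name, submodule in trt_gm.named_children():
        print(name, type(submodule))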

Tracing
-------

``torch_tensorrt.dynamo.trace`` can be used to trace a PyTorch graph and produce an ``ExportedProgram``.
This internally performs some decompositions of operators for downstream optimization.
The ``ExportedProgram`` can then be used with the ``torch_tensorrt.dynamo.compile`` API.
If your model has dynamic input shapes, you can use ``torch_tensorrt.dynamo.trace`` to export
the model with dynamic shapes. Alternatively, you can use ``torch.export`` `with constraints <https://pytorch.org/docs/stable/export.html#expressing-dynamism>`_ directly as well.

.. code-block:: python

    import torch
    import torch_tensorrt

    inputs = [
        torch_tensorrt.Input(
            min_shape=(1, 3, 224, 224),
            opt_shape=(4, 3, 224, 224),
            max_shape=(8, 3, 224, 224),
            dtype=torch.float32,
        )
    ]
    model = MyModel().eval()  # MyModel is a user-defined torch.nn.Module
    exp_program = torch_tensorrt.dynamo.trace(model, inputs)
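
The traced ``ExportedProgram`` can then be compiled and run as in the first example. A minimal continuation of the sketch above (assuming a CUDA device is available):

.. code-block:: python

    # Compile the dynamic-shape ExportedProgram produced by dynamo.trace.
    trt_gm = torch_tensorrt.dynamo.compile(exp_program, inputs)

    # Any batch size within the declared [min_shape, max_shape] range is valid.
    trt_gm(torch.randn((4, 3, 224, 224), dtype=torch.float32).cuda())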

19 changes: 16 additions & 3 deletions docsrc/user_guide/saving_models.rst
@@ -1,16 +1,25 @@
.. _runtime:
.. _saving_models:

Saving models compiled with Torch-TensorRT
============================================
.. currentmodule:: torch_tensorrt.dynamo

.. automodule:: torch_tensorrt.dynamo
    :members:
    :undoc-members:
    :show-inheritance:

Saving models compiled with Torch-TensorRT varies slightly depending on the `ir` that was used for compilation.

1) Dynamo IR
Dynamo IR
-------------

Starting with the 2.1 release of Torch-TensorRT, we are switching the default compilation to be dynamo-based.
The output of `ir=dynamo` compilation is a `torch.fx.GraphModule` object. There are two ways to save these objects:

a) Converting to Torchscript
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

`torch.fx.GraphModule` objects cannot be serialized directly. Hence, we use `torch.jit.trace` to convert them into a `ScriptModule` object, which can be saved to disk.
The following code illustrates this approach.

@@ -30,6 +39,8 @@ The following code illustrates this approach.
model(inputs)
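
A minimal sketch of this TorchScript path (assuming ``trt_gm`` and ``inputs`` from the dynamo compile example earlier; the file's full example is collapsed in this diff view, and the save path is arbitrary):

.. code-block:: python

    import torch

    # Trace the compiled GraphModule into a ScriptModule and save it to disk.
    ts_model = torch.jit.trace(trt_gm, inputs)
    torch.jit.save(ts_model, "trt_model.ts")  # "trt_model.ts" is an arbitrary path

    # Later: load the serialized module and run inference.
    loaded = torch.jit.load("trt_model.ts").cuda()
    loaded(*inputs)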

b) ExportedProgram
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

`torch.export.ExportedProgram` is a new format introduced in PyTorch 2.1. After we compile a PyTorch module using Torch-TensorRT, the resultant
`torch.fx.GraphModule`, along with additional metadata, can be used to create an `ExportedProgram`, which can be saved to and loaded from disk.

@@ -56,7 +67,9 @@ This is needed as `torch._export` serialization cannot handle serializing and de

NOTE: This way of saving the models using `ExportedProgram` is experimental. Here is a known issue: https://github.com/pytorch/TensorRT/issues/2341
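
A sketch of the workflow this section describes (the helper that re-wraps the compiled `torch.fx.GraphModule` into an `ExportedProgram` lives in the collapsed portion of this diff, so `torch_tensorrt.dynamo.export` below is an assumption, as is the file path; `torch._export.save`/`torch._export.load` follow the `torch._export` serialization mentioned above):

.. code-block:: python

    import torch
    import torch_tensorrt

    # Re-wrap the compiled GraphModule (trt_gm) into an ExportedProgram.
    # The helper name torch_tensorrt.dynamo.export is an assumption here.
    trt_exp_program = torch_tensorrt.dynamo.export(trt_gm, inputs)
    torch._export.save(trt_exp_program, "trt_model.ep")

    # Later: load the ExportedProgram and run inference.
    loaded = torch._export.load("trt_model.ep")
    loaded(*inputs)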

2) Torchscript IR

Torchscript IR
--------------

In Torch-TensorRT 1.X versions, the primary way to compile and run inference with Torch-TensorRT was via Torchscript IR.
This behavior stays the same in 2.X versions as well.