feat: add convert_method_to_trt_engine() for dynamo (#2467)
zewenli98 authored Jan 17, 2024
1 parent fd19353 commit 5f66ade
Showing 6 changed files with 274 additions and 15 deletions.
8 changes: 6 additions & 2 deletions py/torch_tensorrt/_compile.py
@@ -319,8 +319,12 @@ def convert_method_to_trt_engine(
            "convert_method_to_trt_engine call is not supported for ir=fx"
        )
    elif target_ir == _IRType.dynamo:
        raise RuntimeError(
            "convert_method_to_trt_engine call is not supported for ir=dynamo."
        return torch_tensorrt.dynamo.convert_module_to_trt_engine(  # type: ignore[no-any-return]
            module,
            inputs=inputs,
            method_name=method_name,
            enabled_precisions=enabled_precisions_set,
            **kwargs,
        )
    elif target_ir == _IRType.torch_compile:
        raise RuntimeError(
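With the dynamo branch wired in above, the top-level torch_tensorrt.convert_method_to_trt_engine() can now hand a traced module to the new dynamo implementation instead of raising. A minimal sketch of the intended call, assuming a CUDA-capable environment (the toy module and shapes are illustrative, not from this PR):

import torch
import torch_tensorrt

class Add(torch.nn.Module):
    def forward(self, a, b):
        return a + b

gm = torch.fx.symbolic_trace(Add())

# ir="dynamo" now routes to torch_tensorrt.dynamo.convert_module_to_trt_engine()
engine_bytes = torch_tensorrt.convert_method_to_trt_engine(
    gm,
    method_name="forward",
    inputs=[torch.randn((2, 4)).cuda(), torch.randn((2, 4)).cuda()],
    ir="dynamo",
    enabled_precisions={torch.float},
)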
2 changes: 1 addition & 1 deletion py/torch_tensorrt/dynamo/__init__.py
@@ -7,7 +7,7 @@
logger = logging.getLogger(__name__)

if version.parse(sanitized_torch_version()) >= version.parse("2.1.dev"):
    from ._compiler import compile
    from ._compiler import compile, convert_module_to_trt_engine
    from ._exporter import export
    from ._settings import CompilationSettings
    from ._SourceIR import SourceIR
189 changes: 189 additions & 0 deletions py/torch_tensorrt/dynamo/_compiler.py
@@ -7,6 +7,7 @@
import torch
from torch.export import ExportedProgram
from torch.fx.node import Target
from torch_tensorrt import _enums
from torch_tensorrt._Device import Device
from torch_tensorrt._enums import (  # TODO: Should probably be the TRT EngineCapability Enum
    EngineCapability,
@@ -47,7 +48,9 @@
)
from torch_tensorrt.dynamo.conversion import (
    CompilationSettings,
    UnsupportedOperatorException,
    convert_module,
    interpret_module_to_result,
    repair_long_or_double_inputs,
)
from torch_tensorrt.dynamo.conversion._ConverterRegistry import (
@@ -443,3 +446,189 @@ def compile_module(
    dryrun_stats_display(dryrun_tracker, settings.dryrun)

    return partitioned_module


def convert_module_to_trt_engine(
    module: torch.fx.GraphModule,
    method_name: str = "forward",
    inputs: Optional[Sequence[Input | torch.Tensor]] = None,
    enabled_precisions: Optional[Set[torch.dtype | _enums.dtype]] = None,
    debug: bool = DEBUG,
    workspace_size: int = WORKSPACE_SIZE,
    min_block_size: int = MIN_BLOCK_SIZE,
    torch_executed_ops: Set[str] = set(),
    pass_through_build_failures: bool = PASS_THROUGH_BUILD_FAILURES,
    max_aux_streams: Optional[int] = MAX_AUX_STREAMS,
    version_compatible: bool = VERSION_COMPATIBLE,
    optimization_level: Optional[int] = OPTIMIZATION_LEVEL,
    use_python_runtime: Optional[bool] = USE_PYTHON_RUNTIME,
    truncate_long_and_double: bool = TRUNCATE_LONG_AND_DOUBLE,
    use_fast_partitioner: bool = USE_FAST_PARTITIONER,
    enable_experimental_decompositions: bool = ENABLE_EXPERIMENTAL_DECOMPOSITIONS,
    device: Device = Device._current_device(),
    require_full_compilation: bool = REQUIRE_FULL_COMPILATION,
    disable_tf32: bool = DISABLE_TF32,
    sparse_weights: bool = SPARSE_WEIGHTS,
    refit: bool = REFIT,
    engine_capability: EngineCapability = ENGINE_CAPABILITY,
    num_avg_timing_iters: int = NUM_AVG_TIMING_ITERS,
    dla_sram_size: int = DLA_SRAM_SIZE,
    dla_local_dram_size: int = DLA_LOCAL_DRAM_SIZE,
    dla_global_dram_size: int = DLA_GLOBAL_DRAM_SIZE,
    calibrator: object = None,
    allow_shape_tensors: bool = False,
) -> bytes:
"""Convert a GraphModule module method to a serialized TensorRT engine
Converts a specified method of a module to a serialized TensorRT engine given a dictionary of conversion settings
Arguments:
module (torch.fx.GraphModule): Source module
Keyword Args:
inputs (List[Union(torch_tensorrt.Input, torch.Tensor)]): **Required** List of specifications of input shape, dtype and memory layout for inputs to the module. This argument is required. Input Sizes can be specified as torch sizes, tuples or lists. dtypes can be specified using
torch datatypes or torch_tensorrt datatypes and you can use either torch devices or the torch_tensorrt device type enum
to select device type. ::
input=[
torch_tensorrt.Input((1, 3, 224, 224)), # Static NCHW input shape for input #1
torch_tensorrt.Input(
min_shape=(1, 224, 224, 3),
opt_shape=(1, 512, 512, 3),
max_shape=(1, 1024, 1024, 3),
dtype=torch.int32
format=torch.channel_last
), # Dynamic input shape for input #2
torch.randn((1, 3, 224, 244)) # Use an example tensor and let torch_tensorrt infer settings
]
method_name (str): Name of method to convert
input_signature Union(List, Tuple, torch_tensorrt.Input, torch.Tensor): A formatted collection of input specifications for the module. Input Sizes can be specified as torch sizes, tuples or lists. dtypes can be specified using
torch datatypes or torch_tensorrt datatypes and you can use either torch devices or the torch_tensorrt device type enum to select device type. **This API should be considered beta-level stable and may change in the future** ::
input_signature=([
torch_tensorrt.Input((1, 3, 224, 224)), # Static NCHW input shape for input #1
torch_tensorrt.Input(
min_shape=(1, 224, 224, 3),
opt_shape=(1, 512, 512, 3),
max_shape=(1, 1024, 1024, 3),
dtype=torch.int32
format=torch.channel_last
), # Dynamic input shape for input #2
], torch.randn((1, 3, 224, 244))) # Use an example tensor and let torch_tensorrt infer settings for input #3
device (Union(torch_tensorrt.Device, torch.device, dict)): Target device for TensorRT engines to run on ::
device=torch_tensorrt.Device("dla:1", allow_gpu_fallback=True)
debug (bool): Whether to print out verbose debugging information
workspace_size (int): Workspace TRT is allowed to use for the module (0 is default)
min_block_size (int): Minimum number of operators per TRT-Engine Block
torch_executed_ops (Sequence[str]): Sequence of operations to run in Torch, regardless of converter coverage
pass_through_build_failures (bool): Whether to fail on TRT engine build errors (True) or not (False)
max_aux_streams (Optional[int]): Maximum number of allowed auxiliary TRT streams for each engine
version_compatible (bool): Provide version forward-compatibility for engine plan files
optimization_level (Optional[int]): Builder optimization 0-5, higher levels imply longer build time,
searching for more optimization options. TRT defaults to 3
use_python_runtime (Optional[bool]): Whether to strictly use Python runtime or C++ runtime. To auto-select a runtime
based on C++ dependency presence (preferentially choosing C++ runtime if available), leave the
argument as None
truncate_long_and_double (bool): Whether to truncate int64/float64 TRT engine inputs or weights to int32/float32
use_fast_partitioner (bool): Whether to use the fast or global graph partitioning system
enable_experimental_decompositions (bool): Whether to enable all core aten decompositions
or only a selected subset of them
device (Device): GPU to compile the model on
require_full_compilation (bool): Whether to require the graph is fully compiled in TensorRT.
Only applicable for `ir="dynamo"`; has no effect for `torch.compile` path
disable_tf32 (bool): Whether to disable TF32 computation for TRT layers
sparse_weights (bool): Whether to allow the builder to use sparse weights
refit (bool): Whether to build a refittable engine
engine_capability (trt.EngineCapability): Restrict kernel selection to safe gpu kernels or safe dla kernels
num_avg_timing_iters (int): Number of averaging timing iterations used to select kernels
dla_sram_size (int): Fast software managed RAM used by DLA to communicate within a layer.
dla_local_dram_size (int): Host RAM used by DLA to share intermediate tensor data across operations
dla_global_dram_size (int): Host RAM used by DLA to store weights and metadata for execution
calibrator (Union(torch_tensorrt._C.IInt8Calibrator, tensorrt.IInt8Calibrator)): Calibrator object which will provide data to the PTQ system for INT8 Calibration
allow_shape_tensors: (Experimental) Allow aten::size to output shape tensors using IShapeLayer in TensorRT
Returns:
bytes: Serialized TensorRT engine, can either be saved to a file or deserialized via TensorRT APIs
"""
    if debug:
        set_log_level(logger.parent, logging.DEBUG)

    input_list = list(inputs) if inputs is not None else []
    # Prepare torch_trt inputs
    input_list = prepare_inputs(input_list)
    device = to_torch_tensorrt_device(device)

    enabled_precisions = (
        enabled_precisions if enabled_precisions is not None else {torch.float}
    )

    if (
        torch.float16 in enabled_precisions
        or torch_tensorrt.dtype.half in enabled_precisions
    ):
        precision = torch.float16
    elif (
        torch.float32 in enabled_precisions
        or torch_tensorrt.dtype.float in enabled_precisions
    ):
        precision = torch.float32
    elif len(enabled_precisions) == 0:
        logger.info(f"No precision specified, defaulting to {PRECISION}")
        precision = PRECISION
    else:
        raise ValueError(
            f"Precision {enabled_precisions} not supported in the Dynamo Path"
        )

    compilation_options = {
        "precision": precision,
        "debug": debug,
        "workspace_size": workspace_size,
        "min_block_size": min_block_size,
        "torch_executed_ops": torch_executed_ops,
        "pass_through_build_failures": pass_through_build_failures,
        "max_aux_streams": max_aux_streams,
        "version_compatible": version_compatible,
        "optimization_level": optimization_level,
        "use_python_runtime": use_python_runtime,
        "truncate_long_and_double": truncate_long_and_double,
        "use_fast_partitioner": use_fast_partitioner,
        "enable_experimental_decompositions": enable_experimental_decompositions,
        "device": device,
        "require_full_compilation": require_full_compilation,
        "disable_tf32": disable_tf32,
        "sparse_weights": sparse_weights,
        "refit": refit,
        "engine_capability": engine_capability,
        "num_avg_timing_iters": num_avg_timing_iters,
        "dla_sram_size": dla_sram_size,
        "dla_local_dram_size": dla_local_dram_size,
        "dla_global_dram_size": dla_global_dram_size,
    }

    settings = CompilationSettings(**compilation_options)
    logger.info("Compilation Settings: %s\n", settings)
    try:
        interpreter_result = interpret_module_to_result(module, input_list, settings)
    except UnsupportedOperatorException:
        logger.error(
            f"Conversion of module {module} not currently fully supported or convertible!",
            exc_info=True,
        )
        raise
    except Exception as e:
        logger.error(
            f"While interpreting the module got an error: {e}",
            exc_info=True,
        )
        raise

    import io

    with io.BytesIO() as engine_bytes:
        engine_bytes.write(interpreter_result.engine.serialize())
        return engine_bytes.getvalue()
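Since the function returns raw bytes, both of the options named in the Returns section apply directly. A short sketch, assuming engine_bytes holds the value returned by the function above (the filename is hypothetical):

import tensorrt as trt

# Option 1: persist the serialized engine to disk
with open("module.engine", "wb") as f:  # hypothetical path
    f.write(engine_bytes)

# Option 2: deserialize it back into a live TensorRT engine
trt_logger = trt.Logger(trt.Logger.WARNING)
with trt.Runtime(trt_logger) as runtime:
    engine = runtime.deserialize_cuda_engine(engine_bytes)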
2 changes: 1 addition & 1 deletion py/torch_tensorrt/dynamo/conversion/__init__.py
@@ -1,5 +1,5 @@
from . import aten_ops_converters, ops_evaluators, prims_ops_converters
from ._conversion import convert_module
from ._conversion import convert_module, interpret_module_to_result
from ._ConversionContext import ConversionContext
from ._ConverterRegistry import * # noqa: F403
from ._TRTInterpreter import * # noqa: F403
41 changes: 30 additions & 11 deletions py/torch_tensorrt/dynamo/conversion/_conversion.py
@@ -7,29 +7,29 @@
import torch
from torch_tensorrt._Input import Input
from torch_tensorrt.dynamo._settings import CompilationSettings
from torch_tensorrt.dynamo.conversion._TRTInterpreter import TRTInterpreter
from torch_tensorrt.dynamo.conversion._TRTInterpreter import (
    TRTInterpreter,
    TRTInterpreterResult,
)
from torch_tensorrt.dynamo.runtime import PythonTorchTensorRTModule, TorchTensorRTModule
from torch_tensorrt.dynamo.utils import get_torch_inputs
from torch_tensorrt.dynamo.utils import get_torch_inputs, to_torch_device


def convert_module(
def interpret_module_to_result(
    module: torch.fx.GraphModule,
    inputs: Sequence[Input],
    settings: CompilationSettings = CompilationSettings(),
    name: str = "",
) -> PythonTorchTensorRTModule | TorchTensorRTModule:
    """Convert an FX module to a TRT module
) -> TRTInterpreterResult:
    """Interpret an FX module to a TRTInterpreterResult

    Args:
        module: FX GraphModule to convert
        module: FX GraphModule to interpret
        inputs: Sequence of Tensors representing inputs to the module
        settings: Compilation settings
        name: TRT engine name
    Returns:
        _PythonTorchTensorRTModule or TorchTensorRTModule
        TRTInterpreterResult
    """
    # Specify module output data types to ensure TRT output types agree with
    # that of the equivalent Torch module
    torch_inputs = get_torch_inputs(inputs, settings.device)
    module.to(to_torch_device(settings.device))
    module_outputs = module(*torch_inputs)

    if not isinstance(module_outputs, (list, tuple)):
@@ -54,6 +54,25 @@ def convert_module(
        compilation_settings=settings,
    )
    interpreter_result = interpreter.run()
    return interpreter_result


def convert_module(
    module: torch.fx.GraphModule,
    inputs: Sequence[Input],
    settings: CompilationSettings = CompilationSettings(),
    name: str = "",
) -> PythonTorchTensorRTModule | TorchTensorRTModule:
    """Convert an FX module to a TRT module

    Args:
        module: FX GraphModule to convert
        inputs: Sequence of Tensors representing inputs to the module
        settings: Compilation settings
        name: TRT engine name
    Returns:
        _PythonTorchTensorRTModule or TorchTensorRTModule
    """
    interpreter_result = interpret_module_to_result(module, inputs, settings)

    if settings.use_python_runtime:
        return PythonTorchTensorRTModule(
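With convert_module() now delegating to interpret_module_to_result(), the runtime-module path and the new serialized-engine path share a single interpreter stage. A rough sketch of that shared flow, assuming a CUDA-capable environment (the toy module is illustrative, and the bytes() round-trip mirrors what convert_module_to_trt_engine() does with io.BytesIO):

import torch
import torch_tensorrt
from torch_tensorrt.dynamo import CompilationSettings
from torch_tensorrt.dynamo.conversion import interpret_module_to_result

class Double(torch.nn.Module):
    def forward(self, x):
        return x * 2

gm = torch.fx.symbolic_trace(Double())

# Single interpreter stage shared by convert_module() and
# convert_module_to_trt_engine()
result = interpret_module_to_result(
    gm, [torch_tensorrt.Input((2, 4))], CompilationSettings()
)
engine_bytes = bytes(result.engine.serialize())  # serialized engine, as the new API returns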
47 changes: 47 additions & 0 deletions tests/py/dynamo/runtime/test_convert_method_to_trt_engine.py
@@ -0,0 +1,47 @@
import unittest

import tensorrt as trt
import torch
import torch_tensorrt
from torch_tensorrt.dynamo.runtime import PythonTorchTensorRTModule
from torch_tensorrt.dynamo.utils import COSINE_THRESHOLD, cosine_similarity


class TestConvertMethodToTrtEngine(unittest.TestCase):
    def test_convert_module(self):
        class Test(torch.nn.Module):
            def forward(self, a, b):
                return torch.add(a, b)

        # Prepare the input data
        input_data_0, input_data_1 = torch.randn((2, 4)), torch.randn((2, 4))

        # Create a model
        model = Test()
        symbolic_traced_gm = torch.fx.symbolic_trace(model)

        # Convert to TensorRT engine
        trt_engine_str = torch_tensorrt.dynamo.convert_module_to_trt_engine(
            symbolic_traced_gm, "forward", inputs=[input_data_0, input_data_1]
        )

        # Deserialize the TensorRT engine
        with trt.Logger() as logger, trt.Runtime(logger) as runtime:
            engine = runtime.deserialize_cuda_engine(trt_engine_str)

        # Inference on TRT Engine
        py_trt_module = PythonTorchTensorRTModule(engine, ["a", "b"], ["output0"])
        trt_output = py_trt_module(input_data_0, input_data_1).cpu()

        # Inference on PyTorch model
        model_output = model(input_data_0, input_data_1)

        cos_sim = cosine_similarity(model_output, trt_output)
        self.assertTrue(
            cos_sim > COSINE_THRESHOLD,
            msg=f"TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
        )


if __name__ == "__main__":
    unittest.main()
