From 458a4d1688d9675ef63bcf6574920cc8cfb6f744 Mon Sep 17 00:00:00 2001 From: lanluo-nvidia Date: Sat, 5 Oct 2024 22:34:12 -0700 Subject: [PATCH 01/33] skip run_shape_analysis --- py/torch_tensorrt/dynamo/_exporter.py | 15 +-- .../dynamo/conversion/_conversion.py | 96 +++++++++++++++++-- .../dynamo/runtime/_TorchTensorRTModule.py | 2 + 3 files changed, 97 insertions(+), 16 deletions(-) diff --git a/py/torch_tensorrt/dynamo/_exporter.py b/py/torch_tensorrt/dynamo/_exporter.py index a7337f4f8e..a237cee8ab 100644 --- a/py/torch_tensorrt/dynamo/_exporter.py +++ b/py/torch_tensorrt/dynamo/_exporter.py @@ -16,7 +16,6 @@ OutputSpec, TensorArgument, ) -from torch_tensorrt.dynamo import partitioning def export( @@ -58,11 +57,9 @@ def transform( if kwarg_inputs is None: kwarg_inputs = {} gm = copy.deepcopy(gm) - # Run shape analysis - _, outputs_map = partitioning.run_shape_analysis(gm, inputs, kwarg_inputs) # Inline TensorRT submodules - inline_trt_modules(gm, outputs_map) + inline_trt_modules(gm) # Inline pytorch submodules inline_torch_modules(gm) @@ -361,9 +358,7 @@ def create_trt_exp_program( return trt_exp_program -def inline_trt_modules( - gm: torch.fx.GraphModule, outputs_map: Dict[Any, Sequence[Any]] -) -> torch.fx.GraphModule: +def inline_trt_modules(gm: torch.fx.GraphModule) -> torch.fx.GraphModule: """ Replace TRT submodules with trt engine nodes. """ @@ -379,7 +374,7 @@ def inline_trt_modules( trt_module_node = trt_module_node[0] assert trt_module_node.args - num_outputs = len(outputs_map[trt_module_node.name]) + num_outputs = len(trt_module.output_shapes) # Insert a call_function node to perform inference on TRT engine with gm.graph.inserting_before(trt_module_node): engine_name = f"{name}_engine" @@ -398,8 +393,8 @@ def inline_trt_modules( cast( FakeTensor, torch.empty_strided( - tuple(outputs_map[trt_module_node.name][idx]), - tuple([1] * len(outputs_map[trt_module_node.name][idx])), + tuple(trt_module.output_shapes[idx]), + tuple([1] * len(trt_module.output_shapes[idx])), ), ) ) diff --git a/py/torch_tensorrt/dynamo/conversion/_conversion.py b/py/torch_tensorrt/dynamo/conversion/_conversion.py index f0b65b3a6e..76ea8cbccb 100644 --- a/py/torch_tensorrt/dynamo/conversion/_conversion.py +++ b/py/torch_tensorrt/dynamo/conversion/_conversion.py @@ -1,8 +1,9 @@ from __future__ import annotations import logging -from typing import Any, List, Optional, Sequence +from typing import Any, List, Optional, Sequence, Tuple +import tensorrt as trt import torch from torch.fx.experimental.proxy_tensor import unset_fake_temporarily from torch_tensorrt._Device import Device @@ -16,13 +17,87 @@ TRTInterpreterResult, ) from torch_tensorrt.dynamo.runtime import PythonTorchTensorRTModule, TorchTensorRTModule -from torch_tensorrt.dynamo.utils import get_model_device, get_torch_inputs - -import tensorrt as trt +from torch_tensorrt.dynamo.utils import ( + get_model_device, + get_torch_inputs, + unwrap_tensor_shape, +) logger = logging.getLogger(__name__) +def get_interpret_result( + module: torch.fx.GraphModule, + inputs: Sequence[Input], + output_dtypes: Sequence[dtype], + settings: CompilationSettings = CompilationSettings(), + engine_cache: Optional[BaseEngineCache] = None, +) -> TRTInterpreterResult: + interpreter = TRTInterpreter( + module, + inputs, + logger_level=(trt.Logger.VERBOSE if settings.debug else trt.Logger.WARNING), + output_dtypes=output_dtypes, + compilation_settings=settings, + engine_cache=engine_cache, + ) + + interpreter_result = interpreter.run() + return interpreter_result + + +def 
infer_module_output_shapes_dtypes( + module: torch.fx.GraphModule, + inputs: Sequence[Input], + device: Device, + kwarg_inputs: Optional[dict[str, Any]] = None, + truncate_double: bool = False, +) -> Tuple[List[Tuple[int]], List[dtype]]: + """ + This function performs model inference to determine the output dtypes + and truncates them accordingly. inputs can be either arg_inputs or flattened input list. + If it is flattened list, kwarg_inputs should be None, as it is already included in the flattened input. + """ + # TODO: We can also determine output dtypes from the module.graph based on node metadata. + # However, our converter tests use fx.symbolic_trace which sometimes does not provide metadata, + # so we stick to the model inference approach currently. + with unset_fake_temporarily(): + # Get the device on which the model exists + # For large models, this can be done on CPU to save GPU memory allocation for TRT. + device = get_model_device(module) + torch_inputs = get_torch_inputs(inputs, device) + if kwarg_inputs is None: + kwarg_inputs = {} + torch_kwarg_inputs = get_torch_inputs(kwarg_inputs, device) + module_outputs = module(*torch_inputs, **torch_kwarg_inputs) + if not isinstance(module_outputs, (list, tuple)): + module_outputs = [module_outputs] + + # Int64 outputs can sometimes be generated from within other operators + # such as aten.sum - such outputs can be truncated + output_dtypes = [] + output_shapes = [] + for output in module_outputs: + output_ = output + # We don't need to check if output is nested here because the input module will be flattened + if not isinstance(output, torch.Tensor): + if isinstance(output, str): + raise ValueError( + f"Received an output type {type(output)} that's not in the acceptable datatypes (https://pytorch.org/docs/stable/tensor_attributes.html#torch.dtype)" + ) + else: + output_ = torch.tensor(output) + + output_shapes.append(unwrap_tensor_shape(output_)) + # output_shapes.append(output_.shape) + if truncate_double and output_.dtype == dtype.float64: + output_dtypes.append(dtype.float32) + else: + output_dtypes.append(dtype._from(output_.dtype)) + + return output_shapes, output_dtypes + + def infer_module_output_dtypes( module: torch.fx.GraphModule, inputs: Sequence[Input], @@ -139,8 +214,16 @@ def convert_module( Returns: PythonTorchTensorRTModule or TorchTensorRTModule """ - interpreter_result = interpret_module_to_result( - module, inputs, settings, engine_cache=engine_cache + + output_shapes, output_dtypes = infer_module_output_shapes_dtypes( + module, + inputs, + settings.device, + truncate_double=settings.truncate_double, + ) + + interpreter_result = get_interpret_result( + module, inputs, output_dtypes, settings, engine_cache=engine_cache ) rt_cls = PythonTorchTensorRTModule @@ -163,6 +246,7 @@ def convert_module( serialized_engine=interpreter_result.serialized_engine, input_binding_names=list(interpreter_result.input_names), output_binding_names=list(interpreter_result.output_names), + output_shapes=output_shapes, name=name, settings=settings, weight_name_map=interpreter_result.weight_name_map, diff --git a/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py index 7bf42da7f0..3c7f08c986 100644 --- a/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py @@ -75,6 +75,7 @@ def __init__( serialized_engine: Optional[bytes] = None, input_binding_names: Optional[List[str]] = None, output_binding_names: 
Optional[List[str]] = None, + output_shapes: Optional[List[Tuple[int]]] = None, *, name: str = "", settings: CompilationSettings = CompilationSettings(), # Assumes engine was built with default compilation settings if object not passed @@ -125,6 +126,7 @@ def __init__( self.output_binding_names = ( output_binding_names if output_binding_names is not None else [] ) + self.output_shapes = output_shapes self.name = name self.hardware_compatible = settings.hardware_compatible self.settings = copy.deepcopy(settings) From 2f408f9a156c9e6d84c94ab839ed4cd34cf070a9 Mon Sep 17 00:00:00 2001 From: lanluo-nvidia Date: Sat, 5 Oct 2024 23:14:37 -0700 Subject: [PATCH 02/33] test --- .../dynamo/runtime/_PythonTorchTensorRTModule.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py index f74c239550..87334c166b 100644 --- a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py @@ -5,6 +5,7 @@ from tempfile import tempdir from typing import Any, Dict, List, Optional, Sequence, Tuple +import tensorrt as trt import torch import torch_tensorrt from torch.nn import Module @@ -19,8 +20,6 @@ multi_gpu_device_check, ) -import tensorrt as trt - logger = logging.getLogger(__name__) @@ -36,6 +35,7 @@ def __init__( serialized_engine: Optional[bytes] = None, input_binding_names: Optional[List[str]] = None, output_binding_names: Optional[List[str]] = None, + output_shapes: Optional[List[Tuple[int]]] = None, *, name: str = "", settings: CompilationSettings = CompilationSettings(), @@ -93,6 +93,7 @@ def __init__( self.output_names = ( output_binding_names if output_binding_names is not None else [] ) + self.output_shapes = output_shapes self.initialized = False self.target_device_id = ( settings.device.gpu_id From 1c5e86c5d3f491f0fd6c049eea41e796c1b16150 Mon Sep 17 00:00:00 2001 From: lanluo-nvidia Date: Sun, 6 Oct 2024 09:27:19 -0700 Subject: [PATCH 03/33] test --- py/torch_tensorrt/_compile.py | 2 +- py/torch_tensorrt/dynamo/_compiler.py | 2 +- py/torch_tensorrt/dynamo/_refit.py | 6 +- .../dynamo/conversion/_conversion.py | 90 ++----------------- tests/py/dynamo/conversion/harness.py | 8 +- 5 files changed, 18 insertions(+), 90 deletions(-) diff --git a/py/torch_tensorrt/_compile.py b/py/torch_tensorrt/_compile.py index cd49962c17..b7aa7c680c 100644 --- a/py/torch_tensorrt/_compile.py +++ b/py/torch_tensorrt/_compile.py @@ -191,7 +191,7 @@ def compile( Returns: torch.nn.Module: Compiled Module, when run it will execute via TensorRT """ - + breakpoint() input_list = inputs if inputs is not None else [] enabled_precisions_set: Set[dtype | torch.dtype] = ( enabled_precisions diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py index 97aa2ec443..830ed10f95 100644 --- a/py/torch_tensorrt/dynamo/_compiler.py +++ b/py/torch_tensorrt/dynamo/_compiler.py @@ -672,7 +672,7 @@ def convert_exported_program_to_serialized_trt_engine( CONVERTERS.set_compilation_settings(settings) try: - interpreter_result = interpret_module_to_result( + _, interpreter_result = interpret_module_to_result( gm, inputs=flattened_input_list, arg_inputs=arg_input_list, diff --git a/py/torch_tensorrt/dynamo/_refit.py b/py/torch_tensorrt/dynamo/_refit.py index 359dc0b3ff..c4abee880f 100644 --- a/py/torch_tensorrt/dynamo/_refit.py +++ b/py/torch_tensorrt/dynamo/_refit.py @@ -14,7 +14,9 @@ from 
torch_tensorrt.dynamo import partitioning from torch_tensorrt.dynamo._exporter import inline_torch_modules from torch_tensorrt.dynamo._settings import CompilationSettings -from torch_tensorrt.dynamo.conversion._conversion import infer_module_output_dtypes +from torch_tensorrt.dynamo.conversion._conversion import ( + infer_module_output_shapes_dtypes, +) from torch_tensorrt.dynamo.conversion._ConverterRegistry import ( DYNAMO_CONVERTERS as CONVERTERS, ) @@ -72,7 +74,7 @@ def construct_refit_mapping( "CONSTANT": (trt.IConstantLayer, [("weights", "CONSTANT")]), } - output_dtypes = infer_module_output_dtypes( + _, output_dtypes = infer_module_output_shapes_dtypes( module, inputs, settings.device, diff --git a/py/torch_tensorrt/dynamo/conversion/_conversion.py b/py/torch_tensorrt/dynamo/conversion/_conversion.py index 76ea8cbccb..d59d4160bc 100644 --- a/py/torch_tensorrt/dynamo/conversion/_conversion.py +++ b/py/torch_tensorrt/dynamo/conversion/_conversion.py @@ -26,26 +26,6 @@ logger = logging.getLogger(__name__) -def get_interpret_result( - module: torch.fx.GraphModule, - inputs: Sequence[Input], - output_dtypes: Sequence[dtype], - settings: CompilationSettings = CompilationSettings(), - engine_cache: Optional[BaseEngineCache] = None, -) -> TRTInterpreterResult: - interpreter = TRTInterpreter( - module, - inputs, - logger_level=(trt.Logger.VERBOSE if settings.debug else trt.Logger.WARNING), - output_dtypes=output_dtypes, - compilation_settings=settings, - engine_cache=engine_cache, - ) - - interpreter_result = interpreter.run() - return interpreter_result - - def infer_module_output_shapes_dtypes( module: torch.fx.GraphModule, inputs: Sequence[Input], @@ -98,55 +78,6 @@ def infer_module_output_shapes_dtypes( return output_shapes, output_dtypes -def infer_module_output_dtypes( - module: torch.fx.GraphModule, - inputs: Sequence[Input], - device: Device, - kwarg_inputs: Optional[dict[str, Any]] = None, - truncate_double: bool = False, -) -> List[dtype]: - """ - This function performs model inference to determine the output dtypes - and truncates them accordingly. inputs can be either arg_inputs or flattened input list. - If it is flattened list, kwarg_inputs should be None, as it is already included in the flattened input. - """ - # TODO: We can also determine output dtypes from the module.graph based on node metadata. - # However, our converter tests use fx.symbolic_trace which sometimes does not provide metadata, - # so we stick to the model inference approach currently. - with unset_fake_temporarily(): - # Get the device on which the model exists - # For large models, this can be done on CPU to save GPU memory allocation for TRT. 
- device = get_model_device(module) - torch_inputs = get_torch_inputs(inputs, device) - if kwarg_inputs is None: - kwarg_inputs = {} - torch_kwarg_inputs = get_torch_inputs(kwarg_inputs, device) - module_outputs = module(*torch_inputs, **torch_kwarg_inputs) - if not isinstance(module_outputs, (list, tuple)): - module_outputs = [module_outputs] - - # Int64 outputs can sometimes be generated from within other operators - # such as aten.sum - such outputs can be truncated - output_dtypes = [] - for output in module_outputs: - output_ = output - # We don't need to check if output is nested here because the input module will be flattened - if not isinstance(output, torch.Tensor): - if isinstance(output, str): - raise ValueError( - f"Received an output type {type(output)} that's not in the acceptable datatypes (https://pytorch.org/docs/stable/tensor_attributes.html#torch.dtype)" - ) - else: - output_ = torch.tensor(output) - - if truncate_double and output_.dtype == dtype.float64: - output_dtypes.append(dtype.float32) - else: - output_dtypes.append(dtype._from(output_.dtype)) - - return output_dtypes - - def interpret_module_to_result( module: torch.fx.GraphModule, inputs: Sequence[Input], @@ -154,7 +85,7 @@ def interpret_module_to_result( arg_inputs: Optional[Sequence[Input]] = None, kwarg_inputs: Optional[dict[str, Any]] = None, engine_cache: Optional[BaseEngineCache] = None, -) -> TRTInterpreterResult: +) -> Tuple[List[Tuple[int]], TRTInterpreterResult]: """Interpret an FX module to a TRTInterpreterResult Args: module: FX GraphModule to interpret @@ -165,10 +96,10 @@ def interpret_module_to_result( settings: Compilation settings engine_cache: Engine cache instance Returns: - TRTInterpreterResult + Output shapes, TRTInterpreterResult """ if arg_inputs is not None: - output_dtypes = infer_module_output_dtypes( + output_shapes, output_dtypes = infer_module_output_shapes_dtypes( module, arg_inputs, settings.device, @@ -177,7 +108,7 @@ def interpret_module_to_result( ) else: # args and kwargs are combined and flattened to one list - output_dtypes = infer_module_output_dtypes( + output_shapes, output_dtypes = infer_module_output_shapes_dtypes( module, inputs, settings.device, @@ -194,7 +125,7 @@ def interpret_module_to_result( ) interpreter_result = interpreter.run() - return interpreter_result + return output_shapes, interpreter_result def convert_module( @@ -215,15 +146,8 @@ def convert_module( PythonTorchTensorRTModule or TorchTensorRTModule """ - output_shapes, output_dtypes = infer_module_output_shapes_dtypes( - module, - inputs, - settings.device, - truncate_double=settings.truncate_double, - ) - - interpreter_result = get_interpret_result( - module, inputs, output_dtypes, settings, engine_cache=engine_cache + output_shapes, interpreter_result = interpret_module_to_result( + module, inputs, settings, engine_cache=engine_cache ) rt_cls = PythonTorchTensorRTModule diff --git a/tests/py/dynamo/conversion/harness.py b/tests/py/dynamo/conversion/harness.py index 632b73e2f3..0a3918f178 100644 --- a/tests/py/dynamo/conversion/harness.py +++ b/tests/py/dynamo/conversion/harness.py @@ -16,7 +16,9 @@ # Use interpreter, input spec, and test case from fx_ts_compat to test Dynamo Converter Registry from torch_tensorrt.dynamo.conversion import TRTInterpreter -from torch_tensorrt.dynamo.conversion._conversion import infer_module_output_dtypes +from torch_tensorrt.dynamo.conversion._conversion import ( + infer_module_output_shapes_dtypes, +) from torch_tensorrt.dynamo.lowering import ( get_decompositions, 
post_lowering, @@ -310,7 +312,7 @@ def run_test( output_dtypes = None if check_dtype: - output_dtypes = infer_module_output_dtypes( + _, output_dtypes = infer_module_output_shapes_dtypes( mod, input_specs, compilation_settings.device, @@ -405,7 +407,7 @@ def run_test_with_dynamic_shape( ) if check_dtype: - output_dtypes = infer_module_output_dtypes( + _, output_dtypes = infer_module_output_shapes_dtypes( mod, input_specs, compilation_settings.device, From ba487dcbb9f69d38ce68362e67c7bd6eca3fc6b3 Mon Sep 17 00:00:00 2001 From: lanluo-nvidia Date: Sun, 6 Oct 2024 09:39:20 -0700 Subject: [PATCH 04/33] test --- py/torch_tensorrt/_compile.py | 1 - py/torch_tensorrt/dynamo/conversion/_conversion.py | 7 +++---- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/py/torch_tensorrt/_compile.py b/py/torch_tensorrt/_compile.py index b7aa7c680c..9654350d4c 100644 --- a/py/torch_tensorrt/_compile.py +++ b/py/torch_tensorrt/_compile.py @@ -191,7 +191,6 @@ def compile( Returns: torch.nn.Module: Compiled Module, when run it will execute via TensorRT """ - breakpoint() input_list = inputs if inputs is not None else [] enabled_precisions_set: Set[dtype | torch.dtype] = ( enabled_precisions diff --git a/py/torch_tensorrt/dynamo/conversion/_conversion.py b/py/torch_tensorrt/dynamo/conversion/_conversion.py index d59d4160bc..1b745adfc6 100644 --- a/py/torch_tensorrt/dynamo/conversion/_conversion.py +++ b/py/torch_tensorrt/dynamo/conversion/_conversion.py @@ -34,7 +34,7 @@ def infer_module_output_shapes_dtypes( truncate_double: bool = False, ) -> Tuple[List[Tuple[int]], List[dtype]]: """ - This function performs model inference to determine the output dtypes + This function performs model inference to determine the output shapes and output dtypes and truncates them accordingly. inputs can be either arg_inputs or flattened input list. If it is flattened list, kwarg_inputs should be None, as it is already included in the flattened input. """ @@ -69,7 +69,6 @@ def infer_module_output_shapes_dtypes( output_ = torch.tensor(output) output_shapes.append(unwrap_tensor_shape(output_)) - # output_shapes.append(output_.shape) if truncate_double and output_.dtype == dtype.float64: output_dtypes.append(dtype.float32) else: @@ -86,7 +85,7 @@ def interpret_module_to_result( kwarg_inputs: Optional[dict[str, Any]] = None, engine_cache: Optional[BaseEngineCache] = None, ) -> Tuple[List[Tuple[int]], TRTInterpreterResult]: - """Interpret an FX module to a TRTInterpreterResult + """Interpret an FX module to the output shapes and a TRTInterpreterResult Args: module: FX GraphModule to interpret inputs: Sequence of FLATTENED Tensors representing inputs to the module. 
It should include both @@ -96,7 +95,7 @@ def interpret_module_to_result( settings: Compilation settings engine_cache: Engine cache instance Returns: - Output shapes, TRTInterpreterResult + (List[Tuple[int]], TRTInterpreterResult) """ if arg_inputs is not None: output_shapes, output_dtypes = infer_module_output_shapes_dtypes( From 2b43480a6471e3f4b74e7a945f57e8c8eb698f1c Mon Sep 17 00:00:00 2001 From: lanluo-nvidia Date: Sun, 6 Oct 2024 09:40:45 -0700 Subject: [PATCH 05/33] test --- py/torch_tensorrt/_compile.py | 1 + 1 file changed, 1 insertion(+) diff --git a/py/torch_tensorrt/_compile.py b/py/torch_tensorrt/_compile.py index 9654350d4c..cd49962c17 100644 --- a/py/torch_tensorrt/_compile.py +++ b/py/torch_tensorrt/_compile.py @@ -191,6 +191,7 @@ def compile( Returns: torch.nn.Module: Compiled Module, when run it will execute via TensorRT """ + input_list = inputs if inputs is not None else [] enabled_precisions_set: Set[dtype | torch.dtype] = ( enabled_precisions From 3d94f8b22ce8ea0a60ee9b5586bc54df40e6b976 Mon Sep 17 00:00:00 2001 From: lanluo-nvidia Date: Sun, 13 Oct 2024 11:13:41 -0700 Subject: [PATCH 06/33] test --- py/torch_tensorrt/_compile.py | 23 +++++++++++++++---- py/torch_tensorrt/dynamo/_exporter.py | 12 ++-------- .../dynamo/models/test_export_kwargs_serde.py | 10 ++++---- tests/py/dynamo/models/test_export_serde.py | 16 ++++++------- tests/py/dynamo/models/test_model_refit.py | 4 ++-- .../runtime/test_002_lazy_engine_init.py | 4 +--- 6 files changed, 36 insertions(+), 33 deletions(-) diff --git a/py/torch_tensorrt/_compile.py b/py/torch_tensorrt/_compile.py index cd49962c17..01394a5f8f 100644 --- a/py/torch_tensorrt/_compile.py +++ b/py/torch_tensorrt/_compile.py @@ -502,6 +502,10 @@ def save( "Provided model is a torch.jit.ScriptModule but the output_format specified is exported_program. Please verify the output_format" ) else: + if arg_inputs is not None: + raise ValueError( + "Provided model is a torch.jit.ScriptModule, do not allow user to provide inputs or arg_inputs." + ) torch.jit.save(module, file_path) elif module_type == _ModuleType.ep: if output_format == "torchscript": @@ -509,12 +513,13 @@ def save( "Provided model is a torch.export.ExportedProgram but the output_format specified is torchscript. Please verify the output_format" ) else: + if arg_inputs is not None: + raise ValueError( + "Provided model is a torch.export.ExportedProgram, do not allow user to provide inputs or arg_inputs during save, it should be provided during export and compile stage" + ) torch.export.save(module, file_path) elif module_type == _ModuleType.fx: - if arg_inputs is None: - raise ValueError( - "Provided model is a torch.fx.GraphModule however the inputs are empty. Please provide valid torch.tensors as inputs to trace and save the model" - ) + # The module type is torch.fx.GraphModule if output_format == "torchscript": module_ts = torch.jit.trace( @@ -525,11 +530,19 @@ def save( if not retrace: from torch_tensorrt.dynamo._exporter import export - exp_program = export(module, arg_inputs, kwarg_inputs) + if arg_inputs is not None: + raise ValueError( + "Provided model is a torch.fx.GraphModule and retrace is False, do not allow user to provide inputs or arg_inputs." + ) + exp_program = export(module) torch.export.save(exp_program, file_path) else: from torch._higher_order_ops.torchbind import enable_torchbind_tracing + if arg_inputs is None: + raise ValueError( + "Provided model is a torch.fx.GraphModule and retrace is True, however the inputs or arg_inputs are empty. 
Please provide valid torch.tensors as inputs or arg_inputs to trace and save the model" + ) with enable_torchbind_tracing(): exp_program = torch.export.export( module, tuple(arg_inputs), kwargs=kwarg_inputs, strict=False diff --git a/py/torch_tensorrt/dynamo/_exporter.py b/py/torch_tensorrt/dynamo/_exporter.py index a237cee8ab..b441c16849 100644 --- a/py/torch_tensorrt/dynamo/_exporter.py +++ b/py/torch_tensorrt/dynamo/_exporter.py @@ -1,6 +1,6 @@ import copy import operator -from typing import Any, Dict, Optional, Sequence, Tuple, cast +from typing import Any, Dict, Sequence, Tuple, cast import torch from torch._guards import detect_fake_mode @@ -20,8 +20,6 @@ def export( gm: torch.fx.GraphModule, - inputs: Sequence[torch.Tensor], - kwarg_inputs: Optional[dict[str, Any]] = None, ) -> ExportedProgram: """Export the result of TensorRT compilation into the desired output format. @@ -29,17 +27,13 @@ def export( gm (torch.fx.GraphModule): Compiled Torch-TensorRT module, generated by ``torch_tensorrt.dynamo.compile`` inputs (torch.Tensor): Torch input tensors """ - if kwarg_inputs is None: - kwarg_inputs = {} - patched_module = transform(gm, inputs, kwarg_inputs) + patched_module = transform(gm) exp_program = create_trt_exp_program(patched_module) return exp_program def transform( gm: torch.fx.GraphModule, - inputs: Sequence[torch.Tensor], - kwarg_inputs: Optional[dict[str, Any]] = None, ) -> torch.fx.GraphModule: """ Transforms the graphmodule by inlining Pytorch and TensorRT submodules. @@ -54,8 +48,6 @@ def transform( """ # Make a copy the graph since this function transforms the input graph and changes it's attributes. # This transformed graph is meant to be consumed by `create_trt_exp_program` - if kwarg_inputs is None: - kwarg_inputs = {} gm = copy.deepcopy(gm) # Inline TensorRT submodules diff --git a/tests/py/dynamo/models/test_export_kwargs_serde.py b/tests/py/dynamo/models/test_export_kwargs_serde.py index 91ee59c0f4..aa4ea14cea 100644 --- a/tests/py/dynamo/models/test_export_kwargs_serde.py +++ b/tests/py/dynamo/models/test_export_kwargs_serde.py @@ -77,7 +77,7 @@ def forward(self, x, b=5, c=None, d=None): # Save the module trt_ep_path = os.path.join(tempfile.gettempdir(), "compiled.ep") - torchtrt.save(trt_gm, trt_ep_path, inputs=args, kwarg_inputs=kwargs) + torchtrt.save(trt_gm, trt_ep_path) # Clean up model env torch._dynamo.reset() @@ -138,7 +138,7 @@ def forward(self, x, b=5, c=None, d=None): # Save the module trt_ep_path = os.path.join(tempfile.gettempdir(), "compiled.ep") - torchtrt.save(trt_gm, trt_ep_path, inputs=args, kwarg_inputs=kwargs) + torchtrt.save(trt_gm, trt_ep_path) # Clean up model env torch._dynamo.reset() @@ -208,7 +208,7 @@ def forward(self, x, b=5, c=None, d=None): # Save the module trt_ep_path = os.path.join(tempfile.gettempdir(), "compiled.ep") - torchtrt.save(trt_gm, trt_ep_path, inputs=args, kwarg_inputs=kwargs) + torchtrt.save(trt_gm, trt_ep_path) # Clean up model env torch._dynamo.reset() @@ -297,7 +297,7 @@ def forward(self, x, b=None, c=None, d=None, e=[]): ) # Save the module trt_ep_path = os.path.join(tempfile.gettempdir(), "compiled.ep") - torchtrt.save(trt_gm, trt_ep_path, inputs=args, kwarg_inputs=kwargs) + torchtrt.save(trt_gm, trt_ep_path) # Clean up model env torch._dynamo.reset() @@ -386,7 +386,7 @@ def forward(self, x, b=None, c=None, d=None, e=[]): ) # Save the module trt_ep_path = os.path.join(tempfile.gettempdir(), "compiled.ep") - torchtrt.save(trt_gm, trt_ep_path, inputs=args, kwarg_inputs=kwargs) + torchtrt.save(trt_gm, trt_ep_path) # 
Clean up model env torch._dynamo.reset() diff --git a/tests/py/dynamo/models/test_export_serde.py b/tests/py/dynamo/models/test_export_serde.py index 470da496ba..4c0b9c6d06 100644 --- a/tests/py/dynamo/models/test_export_serde.py +++ b/tests/py/dynamo/models/test_export_serde.py @@ -48,7 +48,7 @@ def forward(self, x): exp_program = torchtrt.dynamo.trace(model, **compile_spec) trt_module = torchtrt.dynamo.compile(exp_program, **compile_spec) - torchtrt.save(trt_module, trt_ep_path, inputs=[input]) + torchtrt.save(trt_module, trt_ep_path) deser_trt_module = torchtrt.load(trt_ep_path).module() # Check Pyt and TRT exported program outputs @@ -102,7 +102,7 @@ def forward(self, x): exp_program = torchtrt.dynamo.trace(model, **compile_spec) trt_module = torchtrt.dynamo.compile(exp_program, **compile_spec) - torchtrt.save(trt_module, trt_ep_path, inputs=[input]) + torchtrt.save(trt_module, trt_ep_path) deser_trt_module = torchtrt.load(trt_ep_path).module() # Check Pyt and TRT exported program outputs @@ -160,7 +160,7 @@ def forward(self, x): exp_program = torchtrt.dynamo.trace(model, **compile_spec) trt_module = torchtrt.dynamo.compile(exp_program, **compile_spec) - torchtrt.save(trt_module, trt_ep_path, inputs=[input]) + torchtrt.save(trt_module, trt_ep_path) deser_trt_module = torchtrt.load(trt_ep_path).module() # Check Pyt and TRT exported program outputs @@ -221,7 +221,7 @@ def forward(self, x): exp_program = torchtrt.dynamo.trace(model, **compile_spec) trt_module = torchtrt.dynamo.compile(exp_program, **compile_spec) - torchtrt.save(trt_module, trt_ep_path, inputs=[input]) + torchtrt.save(trt_module, trt_ep_path) deser_trt_module = torchtrt.load(trt_ep_path).module() outputs_pyt = model(input) @@ -264,7 +264,7 @@ def test_resnet18(ir): exp_program = torchtrt.dynamo.trace(model, **compile_spec) trt_module = torchtrt.dynamo.compile(exp_program, **compile_spec) - torchtrt.save(trt_module, trt_ep_path, inputs=[input]) + torchtrt.save(trt_module, trt_ep_path) deser_trt_module = torchtrt.load(trt_ep_path).module() outputs_pyt = model(input) @@ -309,7 +309,7 @@ def test_resnet18_dynamic(ir): exp_program = torchtrt.dynamo.trace(model, **compile_spec) trt_module = torchtrt.dynamo.compile(exp_program, **compile_spec) - torchtrt.save(trt_module, trt_ep_path, inputs=[input]) + torchtrt.save(trt_module, trt_ep_path) # TODO: Enable this serialization issues are fixed # deser_trt_module = torchtrt.load(trt_ep_path).module() outputs_pyt = model(input) @@ -359,7 +359,7 @@ def forward(self, x): exp_program = torchtrt.dynamo.trace(model, **compile_spec) trt_module = torchtrt.dynamo.compile(exp_program, **compile_spec) - torchtrt.save(trt_module, trt_ep_path, inputs=[input]) + torchtrt.save(trt_module, trt_ep_path) deser_trt_module = torchtrt.load(trt_ep_path).module() outputs_pyt = model(input) @@ -417,7 +417,7 @@ def forward(self, x): exp_program = torchtrt.dynamo.trace(model, **compile_spec) trt_module = torchtrt.dynamo.compile(exp_program, **compile_spec) - torchtrt.save(trt_module, trt_ep_path, inputs=[input]) + torchtrt.save(trt_module, trt_ep_path) deser_trt_module = torchtrt.load(trt_ep_path).module() outputs_pyt = model(input) diff --git a/tests/py/dynamo/models/test_model_refit.py b/tests/py/dynamo/models/test_model_refit.py index 0f6fb05914..3da840b8fc 100644 --- a/tests/py/dynamo/models/test_model_refit.py +++ b/tests/py/dynamo/models/test_model_refit.py @@ -322,7 +322,7 @@ def test_refit_one_engine_inline_runtime__with_weightmap(): min_block_size=min_block_size, make_refittable=True, ) - 
torchtrt.save(trt_gm, trt_ep_path, inputs=inputs) + torchtrt.save(trt_gm, trt_ep_path) trt_gm = torch.export.load(trt_ep_path) new_trt_gm = refit_module_weights( compiled_module=trt_gm, @@ -592,7 +592,7 @@ def test_refit_one_engine_inline_runtime_without_weightmap(): min_block_size=min_block_size, make_refittable=True, ) - torchtrt.save(trt_gm, trt_ep_path, inputs=inputs) + torchtrt.save(trt_gm, trt_ep_path) trt_gm = torch.export.load(trt_ep_path) new_trt_gm = refit_module_weights( compiled_module=trt_gm, diff --git a/tests/py/dynamo/runtime/test_002_lazy_engine_init.py b/tests/py/dynamo/runtime/test_002_lazy_engine_init.py index aafd099bde..da0dce8f44 100644 --- a/tests/py/dynamo/runtime/test_002_lazy_engine_init.py +++ b/tests/py/dynamo/runtime/test_002_lazy_engine_init.py @@ -239,9 +239,7 @@ def test_lazy_engine_init_cpp_serialization(self): trt_mod = torchtrt.compile(model, **compile_spec) with tempfile.TemporaryDirectory() as tmpdir: - torch_tensorrt.save( - trt_mod, os.path.join(tmpdir, "tmp_trt_mod.ep"), inputs=[input] - ) + torch_tensorrt.save(trt_mod, os.path.join(tmpdir, "tmp_trt_mod.ep")) new_trt_mod = torch.export.load(os.path.join(tmpdir, "tmp_trt_mod.ep")) loaded_trt_mod = new_trt_mod.module() From b89cbe00846862537e5ad648bc487b026c51f320 Mon Sep 17 00:00:00 2001 From: lanluo-nvidia Date: Mon, 14 Oct 2024 22:02:03 -0700 Subject: [PATCH 07/33] resolve comments --- py/torch_tensorrt/_compile.py | 12 ++++++------ py/torch_tensorrt/dynamo/_compiler.py | 2 +- py/torch_tensorrt/dynamo/_refit.py | 6 ++---- py/torch_tensorrt/dynamo/conversion/_conversion.py | 12 ++++++------ tests/py/dynamo/conversion/harness.py | 8 +++----- 5 files changed, 18 insertions(+), 22 deletions(-) diff --git a/py/torch_tensorrt/_compile.py b/py/torch_tensorrt/_compile.py index 01394a5f8f..4726be7141 100644 --- a/py/torch_tensorrt/_compile.py +++ b/py/torch_tensorrt/_compile.py @@ -503,8 +503,8 @@ def save( ) else: if arg_inputs is not None: - raise ValueError( - "Provided model is a torch.jit.ScriptModule, do not allow user to provide inputs or arg_inputs." + logger.warning( + "Provided model is a torch.jit.ScriptModule, inputs or arg_inputs is not necessary during save." ) torch.jit.save(module, file_path) elif module_type == _ModuleType.ep: @@ -514,8 +514,8 @@ def save( ) else: if arg_inputs is not None: - raise ValueError( - "Provided model is a torch.export.ExportedProgram, do not allow user to provide inputs or arg_inputs during save, it should be provided during export and compile stage" + logger.warning( + "Provided model is a torch.export.ExportedProgram, inputs or arg_inputs is not necessary during save, it uses the inputs or arg_inputs provided during export and compile" ) torch.export.save(module, file_path) elif module_type == _ModuleType.fx: @@ -531,8 +531,8 @@ def save( from torch_tensorrt.dynamo._exporter import export if arg_inputs is not None: - raise ValueError( - "Provided model is a torch.fx.GraphModule and retrace is False, do not allow user to provide inputs or arg_inputs." + logger.warning( + "Provided model is a torch.fx.GraphModule and retrace is False, inputs or arg_inputs is not necessary during save." 
) exp_program = export(module) torch.export.save(exp_program, file_path) diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py index 19357937fe..92bab2f304 100644 --- a/py/torch_tensorrt/dynamo/_compiler.py +++ b/py/torch_tensorrt/dynamo/_compiler.py @@ -700,7 +700,7 @@ def convert_exported_program_to_serialized_trt_engine( CONVERTERS.set_compilation_settings(settings) try: - _, interpreter_result = interpret_module_to_result( + interpreter_result, _ = interpret_module_to_result( gm, inputs=flattened_input_list, arg_inputs=arg_input_list, diff --git a/py/torch_tensorrt/dynamo/_refit.py b/py/torch_tensorrt/dynamo/_refit.py index c4abee880f..ac3827f32d 100644 --- a/py/torch_tensorrt/dynamo/_refit.py +++ b/py/torch_tensorrt/dynamo/_refit.py @@ -14,9 +14,7 @@ from torch_tensorrt.dynamo import partitioning from torch_tensorrt.dynamo._exporter import inline_torch_modules from torch_tensorrt.dynamo._settings import CompilationSettings -from torch_tensorrt.dynamo.conversion._conversion import ( - infer_module_output_shapes_dtypes, -) +from torch_tensorrt.dynamo.conversion._conversion import infer_module_outputs from torch_tensorrt.dynamo.conversion._ConverterRegistry import ( DYNAMO_CONVERTERS as CONVERTERS, ) @@ -74,7 +72,7 @@ def construct_refit_mapping( "CONSTANT": (trt.IConstantLayer, [("weights", "CONSTANT")]), } - _, output_dtypes = infer_module_output_shapes_dtypes( + _, output_dtypes = infer_module_outputs( module, inputs, settings.device, diff --git a/py/torch_tensorrt/dynamo/conversion/_conversion.py b/py/torch_tensorrt/dynamo/conversion/_conversion.py index 1b745adfc6..7e9ed6206b 100644 --- a/py/torch_tensorrt/dynamo/conversion/_conversion.py +++ b/py/torch_tensorrt/dynamo/conversion/_conversion.py @@ -26,7 +26,7 @@ logger = logging.getLogger(__name__) -def infer_module_output_shapes_dtypes( +def infer_module_outputs( module: torch.fx.GraphModule, inputs: Sequence[Input], device: Device, @@ -84,7 +84,7 @@ def interpret_module_to_result( arg_inputs: Optional[Sequence[Input]] = None, kwarg_inputs: Optional[dict[str, Any]] = None, engine_cache: Optional[BaseEngineCache] = None, -) -> Tuple[List[Tuple[int]], TRTInterpreterResult]: +) -> Tuple[TRTInterpreterResult, List[Tuple[int]]]: """Interpret an FX module to the output shapes and a TRTInterpreterResult Args: module: FX GraphModule to interpret @@ -95,10 +95,10 @@ def interpret_module_to_result( settings: Compilation settings engine_cache: Engine cache instance Returns: - (List[Tuple[int]], TRTInterpreterResult) + (TRTInterpreterResult, List[Tuple[int]]) """ if arg_inputs is not None: - output_shapes, output_dtypes = infer_module_output_shapes_dtypes( + output_shapes, output_dtypes = infer_module_outputs( module, arg_inputs, settings.device, @@ -107,7 +107,7 @@ def interpret_module_to_result( ) else: # args and kwargs are combined and flattened to one list - output_shapes, output_dtypes = infer_module_output_shapes_dtypes( + output_shapes, output_dtypes = infer_module_outputs( module, inputs, settings.device, @@ -145,7 +145,7 @@ def convert_module( PythonTorchTensorRTModule or TorchTensorRTModule """ - output_shapes, interpreter_result = interpret_module_to_result( + interpreter_result, output_shapes = interpret_module_to_result( module, inputs, settings, engine_cache=engine_cache ) diff --git a/tests/py/dynamo/conversion/harness.py b/tests/py/dynamo/conversion/harness.py index 0a3918f178..35bf575115 100644 --- a/tests/py/dynamo/conversion/harness.py +++ b/tests/py/dynamo/conversion/harness.py @@ -16,9 
+16,7 @@ # Use interpreter, input spec, and test case from fx_ts_compat to test Dynamo Converter Registry from torch_tensorrt.dynamo.conversion import TRTInterpreter -from torch_tensorrt.dynamo.conversion._conversion import ( - infer_module_output_shapes_dtypes, -) +from torch_tensorrt.dynamo.conversion._conversion import infer_module_outputs from torch_tensorrt.dynamo.lowering import ( get_decompositions, post_lowering, @@ -312,7 +310,7 @@ def run_test( output_dtypes = None if check_dtype: - _, output_dtypes = infer_module_output_shapes_dtypes( + _, output_dtypes = infer_module_outputs( mod, input_specs, compilation_settings.device, @@ -407,7 +405,7 @@ def run_test_with_dynamic_shape( ) if check_dtype: - _, output_dtypes = infer_module_output_shapes_dtypes( + _, output_dtypes = infer_module_outputs( mod, input_specs, compilation_settings.device, From 3eb48d786d403b12bd3700004c60e08c5c002f7b Mon Sep 17 00:00:00 2001 From: lanluo-nvidia Date: Tue, 15 Oct 2024 22:10:59 -0700 Subject: [PATCH 08/33] test --- py/torch_tensorrt/dynamo/conversion/_conversion.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/py/torch_tensorrt/dynamo/conversion/_conversion.py b/py/torch_tensorrt/dynamo/conversion/_conversion.py index 7e9ed6206b..195cec1b47 100644 --- a/py/torch_tensorrt/dynamo/conversion/_conversion.py +++ b/py/torch_tensorrt/dynamo/conversion/_conversion.py @@ -124,7 +124,7 @@ def interpret_module_to_result( ) interpreter_result = interpreter.run() - return output_shapes, interpreter_result + return interpreter_result, output_shapes def convert_module( From 50eb0d898eb2dd23a1ea945995b36ddc71781bb5 Mon Sep 17 00:00:00 2001 From: lanluo-nvidia Date: Sat, 19 Oct 2024 20:57:57 -0700 Subject: [PATCH 09/33] replace dummy inference --- py/torch_tensorrt/dynamo/_compiler.py | 39 +++++- py/torch_tensorrt/dynamo/_exporter.py | 20 +-- py/torch_tensorrt/dynamo/_refit.py | 8 +- py/torch_tensorrt/dynamo/_tracer.py | 5 + .../dynamo/conversion/_conversion.py | 132 +++++++++--------- .../runtime/_PythonTorchTensorRTModule.py | 2 - .../dynamo/runtime/_TorchTensorRTModule.py | 2 - tests/py/dynamo/conversion/harness.py | 18 +-- tests/py/dynamo/conversion/test_acos_aten.py | 2 +- tests/py/dynamo/conversion/test_acosh_aten.py | 4 +- tests/py/dynamo/conversion/test_any.py | 8 +- .../py/dynamo/conversion/test_arange_aten.py | 1 + tests/py/dynamo/conversion/test_cat_aten.py | 2 + tests/py/dynamo/conversion/test_full_aten.py | 6 +- tests/py/dynamo/conversion/test_isinf_aten.py | 11 +- tests/py/dynamo/conversion/test_isnan_aten.py | 13 +- .../conversion/test_logical_and_aten.py | 2 +- 17 files changed, 168 insertions(+), 107 deletions(-) diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py index 92bab2f304..5cde0b9a15 100644 --- a/py/torch_tensorrt/dynamo/_compiler.py +++ b/py/torch_tensorrt/dynamo/_compiler.py @@ -421,6 +421,12 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool: if not settings.use_fast_partitioner: dryrun_tracker.to_run_in_torch.extend(parse_non_trt_nodes(partitioned_module)) + submodule_node_dict = {} + for node in partitioned_module.graph.nodes: + if "_run_on_acc" not in node.name: + continue + submodule_node_dict[node.name] = node + # Store TRT replicas of Torch subgraphs trt_modules = {} # Iterate over all components that can be accelerated @@ -440,6 +446,37 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool: ) continue + # set the submodule meta val back to the parent trt_module_node + outputs = [node for node in 
submodule.graph.nodes if node.op == "output"] + outputs = outputs[0].args + outputs_meta_val = [] + for ele in outputs: + # it can be a torch.fx.node.Node or a tuple of torch.fx.node.Node + if isinstance(ele, torch.fx.node.Node): + if "val" not in ele.meta: + raise ValueError( + "expect submodule output node has meta['val'] info" + ) + outputs_meta_val.append(ele.meta["val"]) + elif isinstance(ele, tuple): + for node in ele: + if isinstance(node, torch.fx.node.Node): + if "val" not in ele.meta: + raise ValueError( + "expect submodule output node has meta['val'] info" + ) + outputs_meta_val.append(node.meta["val"]) + else: + raise ValueError(f"not expected types: {type(node)=}") + else: + raise ValueError(f"not expected types: {type(ele)=}") + + if name not in submodule_node_dict: + raise ValueError( + f"node_name: {name} does not exist in the submodule node dictionary" + ) + submodule_node_dict[name].meta["val"] = outputs_meta_val + subgraph_data = PerSubgraphData() subgraph_data.subgraph_name = name subgraph_data.subgraph_op_count = len( @@ -700,7 +737,7 @@ def convert_exported_program_to_serialized_trt_engine( CONVERTERS.set_compilation_settings(settings) try: - interpreter_result, _ = interpret_module_to_result( + interpreter_result = interpret_module_to_result( gm, inputs=flattened_input_list, arg_inputs=arg_input_list, diff --git a/py/torch_tensorrt/dynamo/_exporter.py b/py/torch_tensorrt/dynamo/_exporter.py index b441c16849..d8df7fac5a 100644 --- a/py/torch_tensorrt/dynamo/_exporter.py +++ b/py/torch_tensorrt/dynamo/_exporter.py @@ -366,7 +366,11 @@ def inline_trt_modules(gm: torch.fx.GraphModule) -> torch.fx.GraphModule: trt_module_node = trt_module_node[0] assert trt_module_node.args - num_outputs = len(trt_module.output_shapes) + if "val" not in trt_module_node.meta: + raise ValueError( + f"trt_module_node: {trt_module_node.name} does not have the meta['val'] info" + ) + num_outputs = len(trt_module_node.meta["val"]) # Insert a call_function node to perform inference on TRT engine with gm.graph.inserting_before(trt_module_node): engine_name = f"{name}_engine" @@ -377,19 +381,9 @@ def inline_trt_modules(gm: torch.fx.GraphModule) -> torch.fx.GraphModule: torch.ops.tensorrt.execute_engine.default, (trt_module_node.args, engine_node), ) - trt_node.meta["val"] = [] + # set trt_node.meta with trt_module_node.meta assert num_outputs > 0 - # Generate meta data for TRT node (a FakeTensor with corresponding output shape) - for idx in range(num_outputs): - trt_node.meta["val"].append( - cast( - FakeTensor, - torch.empty_strided( - tuple(trt_module.output_shapes[idx]), - tuple([1] * len(trt_module.output_shapes[idx])), - ), - ) - ) + trt_node.meta["val"] = trt_module_node.meta["val"] # meta["val"] should be a lighter version of a tensor. For eg: it should be a FakeTensor (with output shape and dtype properties) # Lighter version of a custom_obj is not defined clearly. 
meta["val"] does not have any type expectations but diff --git a/py/torch_tensorrt/dynamo/_refit.py b/py/torch_tensorrt/dynamo/_refit.py index ac3827f32d..15e2d2705f 100644 --- a/py/torch_tensorrt/dynamo/_refit.py +++ b/py/torch_tensorrt/dynamo/_refit.py @@ -14,7 +14,7 @@ from torch_tensorrt.dynamo import partitioning from torch_tensorrt.dynamo._exporter import inline_torch_modules from torch_tensorrt.dynamo._settings import CompilationSettings -from torch_tensorrt.dynamo.conversion._conversion import infer_module_outputs +from torch_tensorrt.dynamo.conversion._conversion import infer_module_output_dtypes from torch_tensorrt.dynamo.conversion._ConverterRegistry import ( DYNAMO_CONVERTERS as CONVERTERS, ) @@ -72,10 +72,10 @@ def construct_refit_mapping( "CONSTANT": (trt.IConstantLayer, [("weights", "CONSTANT")]), } - _, output_dtypes = infer_module_outputs( + output_dtypes = infer_module_output_dtypes( module, - inputs, - settings.device, + # inputs, + # settings.device, truncate_double=settings.truncate_double, ) diff --git a/py/torch_tensorrt/dynamo/_tracer.py b/py/torch_tensorrt/dynamo/_tracer.py index 78f7989777..2fb2b3080d 100644 --- a/py/torch_tensorrt/dynamo/_tracer.py +++ b/py/torch_tensorrt/dynamo/_tracer.py @@ -115,6 +115,9 @@ def get_dynamic_shapes_args(mod: torch.nn.Module, inputs: Any) -> dict[str, Any] args = list(signature(mod.forward).parameters.keys()) dynamic_shapes = {} for input, input_name in zip(inputs, args[: len(inputs)]): + # if input.name is not None, also not empty str, use the input.name + if input.name is not None and len(input.name) > 0 and input.name != input_name: + input_name = input.name dynamic_shapes[input_name] = get_dynamic_shapes(input) return dynamic_shapes @@ -131,11 +134,13 @@ def get_dynamic_shapes(input: Input) -> dict[Any, Any]: max_shape = input.shape["max_shape"] assert len(min_shape) == len(opt_shape) == len(max_shape) for dim in range(len(min_shape)): + # reverse_dim = len(min_shape)-1 - dim if min_shape[dim] == opt_shape[dim] == max_shape[dim]: continue else: dynamic_dims[dim] = Dim( input.name + "_" + str(dim), + # input.name + "_" + str(reverse_dim), min=min_shape[dim], max=max_shape[dim], ) diff --git a/py/torch_tensorrt/dynamo/conversion/_conversion.py b/py/torch_tensorrt/dynamo/conversion/_conversion.py index 195cec1b47..dea1aef0a9 100644 --- a/py/torch_tensorrt/dynamo/conversion/_conversion.py +++ b/py/torch_tensorrt/dynamo/conversion/_conversion.py @@ -1,12 +1,11 @@ from __future__ import annotations import logging -from typing import Any, List, Optional, Sequence, Tuple +from typing import Any, List, Optional, Sequence import tensorrt as trt import torch -from torch.fx.experimental.proxy_tensor import unset_fake_temporarily -from torch_tensorrt._Device import Device +from torch._subclasses.fake_tensor import FakeTensor from torch_tensorrt._enums import dtype from torch_tensorrt._features import ENABLED_FEATURES from torch_tensorrt._Input import Input @@ -17,64 +16,83 @@ TRTInterpreterResult, ) from torch_tensorrt.dynamo.runtime import PythonTorchTensorRTModule, TorchTensorRTModule -from torch_tensorrt.dynamo.utils import ( - get_model_device, - get_torch_inputs, - unwrap_tensor_shape, -) + +# from torch_tensorrt.dynamo.utils import ( +# get_model_device, +# get_torch_inputs, +# unwrap_tensor_shape, +# ) logger = logging.getLogger(__name__) -def infer_module_outputs( +def get_output_dtypes(output: Any, truncate_doulbe: bool = False) -> List[dtype]: + output_dtypes = [] + if isinstance(output, torch.fx.node.Node): + if "val" in 
output.meta: + output_meta = output.meta["val"] + if isinstance(output_meta, (FakeTensor, torch.Tensor)): + if truncate_doulbe and output_meta.dtype == torch.float64: + output_dtypes.append(dtype.float32) + else: + output_dtypes.append(dtype._from(output_meta.dtype)) + else: + raise ValueError( + "meta['val'] does not exist, expect meta['val'] exists for each output node" + ) + elif isinstance(output, tuple): + for ele in output: + output_dtypes.extend(get_output_dtypes(ele)) + else: + raise ValueError( + f"got type {type(output)}, expected type is a torch.fx.node.Node or a tuple of torch.fx.node.Node" + ) + return output_dtypes + + +def infer_module_output_dtypes( module: torch.fx.GraphModule, - inputs: Sequence[Input], - device: Device, - kwarg_inputs: Optional[dict[str, Any]] = None, + # inputs: Sequence[Input], + # device: Device, + # kwarg_inputs: Optional[dict[str, Any]] = None, truncate_double: bool = False, -) -> Tuple[List[Tuple[int]], List[dtype]]: +) -> List[dtype]: """ This function performs model inference to determine the output shapes and output dtypes and truncates them accordingly. inputs can be either arg_inputs or flattened input list. If it is flattened list, kwarg_inputs should be None, as it is already included in the flattened input. """ + outputs = [node for node in module.graph.nodes if node.op == "output"] + outputs = outputs[0].args + return get_output_dtypes(outputs, truncate_double) + # TODO: We can also determine output dtypes from the module.graph based on node metadata. # However, our converter tests use fx.symbolic_trace which sometimes does not provide metadata, # so we stick to the model inference approach currently. - with unset_fake_temporarily(): - # Get the device on which the model exists - # For large models, this can be done on CPU to save GPU memory allocation for TRT. - device = get_model_device(module) - torch_inputs = get_torch_inputs(inputs, device) - if kwarg_inputs is None: - kwarg_inputs = {} - torch_kwarg_inputs = get_torch_inputs(kwarg_inputs, device) - module_outputs = module(*torch_inputs, **torch_kwarg_inputs) - if not isinstance(module_outputs, (list, tuple)): - module_outputs = [module_outputs] + # with unset_fake_temporarily(): + # # Get the device on which the model exists + # # For large models, this can be done on CPU to save GPU memory allocation for TRT. 
+ # device = get_model_device(module) + # torch_inputs = get_torch_inputs(inputs, device) + # if kwarg_inputs is None: + # kwarg_inputs = {} + # torch_kwarg_inputs = get_torch_inputs(kwarg_inputs, device) + # module_outputs = module(*torch_inputs, **torch_kwarg_inputs) + # if not isinstance(module_outputs, (list, tuple)): + # module_outputs = [module_outputs] # Int64 outputs can sometimes be generated from within other operators - # such as aten.sum - such outputs can be truncated - output_dtypes = [] - output_shapes = [] - for output in module_outputs: - output_ = output - # We don't need to check if output is nested here because the input module will be flattened - if not isinstance(output, torch.Tensor): - if isinstance(output, str): - raise ValueError( - f"Received an output type {type(output)} that's not in the acceptable datatypes (https://pytorch.org/docs/stable/tensor_attributes.html#torch.dtype)" - ) - else: - output_ = torch.tensor(output) - - output_shapes.append(unwrap_tensor_shape(output_)) - if truncate_double and output_.dtype == dtype.float64: - output_dtypes.append(dtype.float32) - else: - output_dtypes.append(dtype._from(output_.dtype)) + # # such as aten.sum - such outputs can be truncated + # output_dtypes_ret = [] + + # for output_dtype in output_dtypes: + + # if truncate_double and output_dtype == dtype.float64: + # output_dtypes_ret.append(dtype.float32) + # else: + # output_dtypes_ret.append(dtype._from(output_dtype)) - return output_shapes, output_dtypes + # return output_shapes, output_dtypes def interpret_module_to_result( @@ -84,7 +102,7 @@ def interpret_module_to_result( arg_inputs: Optional[Sequence[Input]] = None, kwarg_inputs: Optional[dict[str, Any]] = None, engine_cache: Optional[BaseEngineCache] = None, -) -> Tuple[TRTInterpreterResult, List[Tuple[int]]]: +) -> TRTInterpreterResult: """Interpret an FX module to the output shapes and a TRTInterpreterResult Args: module: FX GraphModule to interpret @@ -97,22 +115,9 @@ def interpret_module_to_result( Returns: (TRTInterpreterResult, List[Tuple[int]]) """ - if arg_inputs is not None: - output_shapes, output_dtypes = infer_module_outputs( - module, - arg_inputs, - settings.device, - kwarg_inputs=kwarg_inputs, - truncate_double=settings.truncate_double, - ) - else: - # args and kwargs are combined and flattened to one list - output_shapes, output_dtypes = infer_module_outputs( - module, - inputs, - settings.device, - truncate_double=settings.truncate_double, - ) + output_dtypes = infer_module_output_dtypes( + module, truncate_double=settings.truncate_double + ) interpreter = TRTInterpreter( module, @@ -124,7 +129,7 @@ def interpret_module_to_result( ) interpreter_result = interpreter.run() - return interpreter_result, output_shapes + return interpreter_result def convert_module( @@ -145,7 +150,7 @@ def convert_module( PythonTorchTensorRTModule or TorchTensorRTModule """ - interpreter_result, output_shapes = interpret_module_to_result( + interpreter_result = interpret_module_to_result( module, inputs, settings, engine_cache=engine_cache ) @@ -169,7 +174,6 @@ def convert_module( serialized_engine=interpreter_result.serialized_engine, input_binding_names=list(interpreter_result.input_names), output_binding_names=list(interpreter_result.output_names), - output_shapes=output_shapes, name=name, settings=settings, weight_name_map=interpreter_result.weight_name_map, diff --git a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py index 
6c043bc803..1f84b7c400 100644 --- a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py @@ -35,7 +35,6 @@ def __init__( serialized_engine: Optional[bytes] = None, input_binding_names: Optional[List[str]] = None, output_binding_names: Optional[List[str]] = None, - output_shapes: Optional[List[Tuple[int]]] = None, *, name: str = "", settings: CompilationSettings = CompilationSettings(), @@ -93,7 +92,6 @@ def __init__( self.output_names = ( output_binding_names if output_binding_names is not None else [] ) - self.output_shapes = output_shapes self.initialized = False self.target_device_id = ( settings.device.gpu_id diff --git a/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py index 3c7f08c986..7bf42da7f0 100644 --- a/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py @@ -75,7 +75,6 @@ def __init__( serialized_engine: Optional[bytes] = None, input_binding_names: Optional[List[str]] = None, output_binding_names: Optional[List[str]] = None, - output_shapes: Optional[List[Tuple[int]]] = None, *, name: str = "", settings: CompilationSettings = CompilationSettings(), # Assumes engine was built with default compilation settings if object not passed @@ -126,7 +125,6 @@ def __init__( self.output_binding_names = ( output_binding_names if output_binding_names is not None else [] ) - self.output_shapes = output_shapes self.name = name self.hardware_compatible = settings.hardware_compatible self.settings = copy.deepcopy(settings) diff --git a/tests/py/dynamo/conversion/harness.py b/tests/py/dynamo/conversion/harness.py index 35bf575115..a692811c44 100644 --- a/tests/py/dynamo/conversion/harness.py +++ b/tests/py/dynamo/conversion/harness.py @@ -16,7 +16,7 @@ # Use interpreter, input spec, and test case from fx_ts_compat to test Dynamo Converter Registry from torch_tensorrt.dynamo.conversion import TRTInterpreter -from torch_tensorrt.dynamo.conversion._conversion import infer_module_outputs +from torch_tensorrt.dynamo.conversion._conversion import infer_module_output_dtypes from torch_tensorrt.dynamo.lowering import ( get_decompositions, post_lowering, @@ -259,7 +259,7 @@ def run_test( atol=ATOL, precision=dtype.f32, check_dtype=True, - use_dynamo_tracer=False, + use_dynamo_tracer=True, enable_passes=False, propagate_shapes=False, int32_reqd=False, @@ -310,10 +310,10 @@ def run_test( output_dtypes = None if check_dtype: - _, output_dtypes = infer_module_outputs( + output_dtypes = infer_module_output_dtypes( mod, - input_specs, - compilation_settings.device, + # input_specs, + # compilation_settings.device, truncate_double=compilation_settings.truncate_double, ) @@ -382,7 +382,7 @@ def run_test_with_dynamic_shape( rtol=RTOL, atol=ATOL, output_dtypes=None, - use_dynamo_tracer=False, + use_dynamo_tracer=True, enable_passes=False, use_example_tensors=True, pyt_inputs=None, @@ -405,10 +405,10 @@ def run_test_with_dynamic_shape( ) if check_dtype: - _, output_dtypes = infer_module_outputs( + output_dtypes = infer_module_output_dtypes( mod, - input_specs, - compilation_settings.device, + # input_specs, + # compilation_settings.device, truncate_double=compilation_settings.truncate_double, ) diff --git a/tests/py/dynamo/conversion/test_acos_aten.py b/tests/py/dynamo/conversion/test_acos_aten.py index 81b83bcc4a..8e93e0a309 100644 --- a/tests/py/dynamo/conversion/test_acos_aten.py +++ 
b/tests/py/dynamo/conversion/test_acos_aten.py @@ -66,7 +66,7 @@ def forward(self, input): ( "3d_dim_dtype_float", (1, 1, 1), - (1, 2, 3), + (2, 2, 3), (3, 3, 3), torch.float, torch.float, diff --git a/tests/py/dynamo/conversion/test_acosh_aten.py b/tests/py/dynamo/conversion/test_acosh_aten.py index 090756ddfb..dd19c188c5 100644 --- a/tests/py/dynamo/conversion/test_acosh_aten.py +++ b/tests/py/dynamo/conversion/test_acosh_aten.py @@ -58,7 +58,7 @@ def forward(self, input): ( "3d_dim_dtype_float", (1, 1, 1), - (1, 2, 3), + (2, 2, 3), (3, 3, 3), torch.float, torch.float, @@ -66,7 +66,7 @@ def forward(self, input): ( "3d_dim_dtype_int32", (1, 1, 1), - (1, 2, 4), + (2, 2, 4), (2, 3, 5), torch.int32, torch.float, diff --git a/tests/py/dynamo/conversion/test_any.py b/tests/py/dynamo/conversion/test_any.py index 1d1fc634ef..75620f7c34 100644 --- a/tests/py/dynamo/conversion/test_any.py +++ b/tests/py/dynamo/conversion/test_any.py @@ -191,7 +191,7 @@ class TestAnyConverterDynamic(DispatchTestCase): ( "3d_dynamic_float", (2, 1, 1), - (2, 2, 1), + (2, 2, 2), (3, 2, 4), torch.float, ), @@ -234,7 +234,7 @@ def forward(self, x): ( "3d_dynamic_dim_float", (2, 1, 1), - (2, 2, 1), + (2, 2, 2), (3, 2, 4), torch.float, 2, @@ -252,7 +252,7 @@ def forward(self, x): ( "3d_dynamic_dim_bool", (2, 1, 1), - (2, 2, 1), + (2, 2, 2), (3, 2, 4), torch.bool, 0, @@ -285,7 +285,7 @@ def forward(self, x): ( "3d_dynamic_dims_float", (2, 1, 1), - (2, 2, 1), + (2, 2, 2), (3, 2, 4), torch.float, [1, 2], diff --git a/tests/py/dynamo/conversion/test_arange_aten.py b/tests/py/dynamo/conversion/test_arange_aten.py index 1e1e9b3cc7..cb3e4c6b51 100644 --- a/tests/py/dynamo/conversion/test_arange_aten.py +++ b/tests/py/dynamo/conversion/test_arange_aten.py @@ -56,6 +56,7 @@ def forward(self, end_tensor): use_example_tensors=False, check_dtype=False, pyt_inputs=[pyt_input], + use_dynamo_tracer=False, ) diff --git a/tests/py/dynamo/conversion/test_cat_aten.py b/tests/py/dynamo/conversion/test_cat_aten.py index a9e4a45c81..abe96af271 100644 --- a/tests/py/dynamo/conversion/test_cat_aten.py +++ b/tests/py/dynamo/conversion/test_cat_aten.py @@ -42,12 +42,14 @@ def forward(self, x, y): min_shape=(16, 2, 3), opt_shape=(16, 3, 3), max_shape=(16, 32, 3), + name="x", ), Input( dtype=torch.float32, min_shape=(16, 2, 3), opt_shape=(16, 16, 3), max_shape=(16, 32, 3), + name="y", ), ] self.run_test_with_dynamic_shape( diff --git a/tests/py/dynamo/conversion/test_full_aten.py b/tests/py/dynamo/conversion/test_full_aten.py index 29b48d1451..44d30af43e 100644 --- a/tests/py/dynamo/conversion/test_full_aten.py +++ b/tests/py/dynamo/conversion/test_full_aten.py @@ -50,7 +50,11 @@ def forward(self, shape): ) ] self.run_test_with_dynamic_shape( - full(), inputs, use_example_tensors=False, check_dtype=False + full(), + inputs, + use_example_tensors=False, + check_dtype=False, + use_dynamo_tracer=False, ) @parameterized.expand( diff --git a/tests/py/dynamo/conversion/test_isinf_aten.py b/tests/py/dynamo/conversion/test_isinf_aten.py index d8051c1f41..25aad4381b 100644 --- a/tests/py/dynamo/conversion/test_isinf_aten.py +++ b/tests/py/dynamo/conversion/test_isinf_aten.py @@ -50,7 +50,12 @@ def forward(self, input): max_shape=(5, 3, 3), dtype=torch.float32, torch_tensor=torch.tensor( - ([[[2.7, float("-inf"), 1.1], [4.7, -2.3, float("inf")]]]), + ( + [ + [[2.7, float("-inf"), 1.1], [4.7, -2.3, float("inf")]], + [[2.7, float("-inf"), 1.1], [4.7, -2.3, float("inf")]], + ] + ), dtype=torch.float32, ).cuda(), ) @@ -72,7 +77,9 @@ def forward(self, input): 
opt_shape=(3, 2), max_shape=(5, 3), dtype=torch.int, - torch_tensor=torch.tensor(([[-3, 2]]), dtype=torch.int).cuda(), + torch_tensor=torch.tensor( + ([[-3, 2], [-2, 1], [1, 2]]), dtype=torch.int + ).cuda(), ) ] self.run_test_with_dynamic_shape( diff --git a/tests/py/dynamo/conversion/test_isnan_aten.py b/tests/py/dynamo/conversion/test_isnan_aten.py index 62ba24f319..ad913f9735 100644 --- a/tests/py/dynamo/conversion/test_isnan_aten.py +++ b/tests/py/dynamo/conversion/test_isnan_aten.py @@ -52,7 +52,18 @@ def forward(self, input): max_shape=(5, 3, 3), dtype=torch.float32, torch_tensor=torch.tensor( - ([[[3.2, float("nan"), 3.1], [float("inf"), 1.1, float("nan")]]]), + ( + [ + [ + [3.2, float("nan"), 3.1], + [float("inf"), 1.1, float("nan")], + ], + [ + [3.2, float("nan"), 3.1], + [float("inf"), 1.1, float("nan")], + ], + ] + ), dtype=torch.float32, ).cuda(), ) diff --git a/tests/py/dynamo/conversion/test_logical_and_aten.py b/tests/py/dynamo/conversion/test_logical_and_aten.py index 9ccd96e81c..64e8a4c839 100644 --- a/tests/py/dynamo/conversion/test_logical_and_aten.py +++ b/tests/py/dynamo/conversion/test_logical_and_aten.py @@ -37,7 +37,7 @@ def forward(self, lhs_val, rhs_val): ( "3d_dim_dtype_bool", (1, 1, 1), - (1, 2, 3), + (2, 2, 3), (3, 3, 3), torch.bool, ), From 95ed60249c387e079f5926ff795fea94fb6be970 Mon Sep 17 00:00:00 2001 From: lanluo-nvidia Date: Sat, 19 Oct 2024 21:33:19 -0700 Subject: [PATCH 10/33] test --- py/torch_tensorrt/dynamo/_compiler.py | 12 +++-- py/torch_tensorrt/dynamo/_exporter.py | 2 +- py/torch_tensorrt/dynamo/_refit.py | 2 - py/torch_tensorrt/dynamo/_tracer.py | 2 - .../dynamo/conversion/_conversion.py | 51 +++---------------- tests/py/dynamo/conversion/harness.py | 4 -- 6 files changed, 15 insertions(+), 58 deletions(-) diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py index 5cde0b9a15..973b8d3d5f 100644 --- a/py/torch_tensorrt/dynamo/_compiler.py +++ b/py/torch_tensorrt/dynamo/_compiler.py @@ -455,7 +455,7 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool: if isinstance(ele, torch.fx.node.Node): if "val" not in ele.meta: raise ValueError( - "expect submodule output node has meta['val'] info" + f"node.name={ele.name}: meta['val'] does not exist, expect submodule output node has meta['val'] info" ) outputs_meta_val.append(ele.meta["val"]) elif isinstance(ele, tuple): @@ -463,13 +463,17 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool: if isinstance(node, torch.fx.node.Node): if "val" not in ele.meta: raise ValueError( - "expect submodule output node has meta['val'] info" + f"{node.name=}: meta['val'] does not exist, expect submodule output node has meta['val'] info" ) outputs_meta_val.append(node.meta["val"]) else: - raise ValueError(f"not expected types: {type(node)=}") + raise ValueError( + f"expect torch.fx.node.Node type, got not expected types: {type(node)=}" + ) else: - raise ValueError(f"not expected types: {type(ele)=}") + raise ValueError( + f"expect torch.fx.node.Node or tuple of torch.fx.node.Node type, got not expected types: {type(ele)=}" + ) if name not in submodule_node_dict: raise ValueError( diff --git a/py/torch_tensorrt/dynamo/_exporter.py b/py/torch_tensorrt/dynamo/_exporter.py index d8df7fac5a..8bb67067b3 100644 --- a/py/torch_tensorrt/dynamo/_exporter.py +++ b/py/torch_tensorrt/dynamo/_exporter.py @@ -368,7 +368,7 @@ def inline_trt_modules(gm: torch.fx.GraphModule) -> torch.fx.GraphModule: if "val" not in trt_module_node.meta: raise ValueError( - f"trt_module_node: 
{trt_module_node.name} does not have the meta['val'] info" + f"trt_module_node: {trt_module_node.name} does not have the meta['val'] info, it should be set during dynamo compile_module step." ) num_outputs = len(trt_module_node.meta["val"]) # Insert a call_function node to perform inference on TRT engine diff --git a/py/torch_tensorrt/dynamo/_refit.py b/py/torch_tensorrt/dynamo/_refit.py index 15e2d2705f..c27169a1a9 100644 --- a/py/torch_tensorrt/dynamo/_refit.py +++ b/py/torch_tensorrt/dynamo/_refit.py @@ -74,8 +74,6 @@ def construct_refit_mapping( output_dtypes = infer_module_output_dtypes( module, - # inputs, - # settings.device, truncate_double=settings.truncate_double, ) diff --git a/py/torch_tensorrt/dynamo/_tracer.py b/py/torch_tensorrt/dynamo/_tracer.py index 2fb2b3080d..2c8745ee33 100644 --- a/py/torch_tensorrt/dynamo/_tracer.py +++ b/py/torch_tensorrt/dynamo/_tracer.py @@ -134,13 +134,11 @@ def get_dynamic_shapes(input: Input) -> dict[Any, Any]: max_shape = input.shape["max_shape"] assert len(min_shape) == len(opt_shape) == len(max_shape) for dim in range(len(min_shape)): - # reverse_dim = len(min_shape)-1 - dim if min_shape[dim] == opt_shape[dim] == max_shape[dim]: continue else: dynamic_dims[dim] = Dim( input.name + "_" + str(dim), - # input.name + "_" + str(reverse_dim), min=min_shape[dim], max=max_shape[dim], ) diff --git a/py/torch_tensorrt/dynamo/conversion/_conversion.py b/py/torch_tensorrt/dynamo/conversion/_conversion.py index dea1aef0a9..589d94fbb2 100644 --- a/py/torch_tensorrt/dynamo/conversion/_conversion.py +++ b/py/torch_tensorrt/dynamo/conversion/_conversion.py @@ -17,12 +17,6 @@ ) from torch_tensorrt.dynamo.runtime import PythonTorchTensorRTModule, TorchTensorRTModule -# from torch_tensorrt.dynamo.utils import ( -# get_model_device, -# get_torch_inputs, -# unwrap_tensor_shape, -# ) - logger = logging.getLogger(__name__) @@ -38,62 +32,30 @@ def get_output_dtypes(output: Any, truncate_doulbe: bool = False) -> List[dtype] output_dtypes.append(dtype._from(output_meta.dtype)) else: raise ValueError( - "meta['val'] does not exist, expect meta['val'] exists for each output node" + f"node.name={output.name}: node.meta['val'] does not exist, expect node.meta['val'] exists for each output node" ) elif isinstance(output, tuple): for ele in output: output_dtypes.extend(get_output_dtypes(ele)) else: raise ValueError( - f"got type {type(output)}, expected type is a torch.fx.node.Node or a tuple of torch.fx.node.Node" + f"got unexpected type {type(output)}, expected type is a torch.fx.node.Node or a tuple of torch.fx.node.Node" ) return output_dtypes def infer_module_output_dtypes( module: torch.fx.GraphModule, - # inputs: Sequence[Input], - # device: Device, - # kwarg_inputs: Optional[dict[str, Any]] = None, truncate_double: bool = False, ) -> List[dtype]: """ - This function performs model inference to determine the output shapes and output dtypes - and truncates them accordingly. inputs can be either arg_inputs or flattened input list. - If it is flattened list, kwarg_inputs should be None, as it is already included in the flattened input. + This function get the output dtypes from node.meta['val'] which was set during dynamo compile_module step + and truncates them accordingly. """ outputs = [node for node in module.graph.nodes if node.op == "output"] outputs = outputs[0].args return get_output_dtypes(outputs, truncate_double) - # TODO: We can also determine output dtypes from the module.graph based on node metadata. 
- # However, our converter tests use fx.symbolic_trace which sometimes does not provide metadata, - # so we stick to the model inference approach currently. - # with unset_fake_temporarily(): - # # Get the device on which the model exists - # # For large models, this can be done on CPU to save GPU memory allocation for TRT. - # device = get_model_device(module) - # torch_inputs = get_torch_inputs(inputs, device) - # if kwarg_inputs is None: - # kwarg_inputs = {} - # torch_kwarg_inputs = get_torch_inputs(kwarg_inputs, device) - # module_outputs = module(*torch_inputs, **torch_kwarg_inputs) - # if not isinstance(module_outputs, (list, tuple)): - # module_outputs = [module_outputs] - - # Int64 outputs can sometimes be generated from within other operators - # # such as aten.sum - such outputs can be truncated - # output_dtypes_ret = [] - - # for output_dtype in output_dtypes: - - # if truncate_double and output_dtype == dtype.float64: - # output_dtypes_ret.append(dtype.float32) - # else: - # output_dtypes_ret.append(dtype._from(output_dtype)) - - # return output_shapes, output_dtypes - def interpret_module_to_result( module: torch.fx.GraphModule, @@ -103,7 +65,7 @@ def interpret_module_to_result( kwarg_inputs: Optional[dict[str, Any]] = None, engine_cache: Optional[BaseEngineCache] = None, ) -> TRTInterpreterResult: - """Interpret an FX module to the output shapes and a TRTInterpreterResult + """Interpret an FX module to a TRTInterpreterResult Args: module: FX GraphModule to interpret inputs: Sequence of FLATTENED Tensors representing inputs to the module. It should include both @@ -113,7 +75,7 @@ def interpret_module_to_result( settings: Compilation settings engine_cache: Engine cache instance Returns: - (TRTInterpreterResult, List[Tuple[int]]) + TRTInterpreterResult """ output_dtypes = infer_module_output_dtypes( module, truncate_double=settings.truncate_double @@ -149,7 +111,6 @@ def convert_module( Returns: PythonTorchTensorRTModule or TorchTensorRTModule """ - interpreter_result = interpret_module_to_result( module, inputs, settings, engine_cache=engine_cache ) diff --git a/tests/py/dynamo/conversion/harness.py b/tests/py/dynamo/conversion/harness.py index a692811c44..642140775f 100644 --- a/tests/py/dynamo/conversion/harness.py +++ b/tests/py/dynamo/conversion/harness.py @@ -312,8 +312,6 @@ def run_test( if check_dtype: output_dtypes = infer_module_output_dtypes( mod, - # input_specs, - # compilation_settings.device, truncate_double=compilation_settings.truncate_double, ) @@ -407,8 +405,6 @@ def run_test_with_dynamic_shape( if check_dtype: output_dtypes = infer_module_output_dtypes( mod, - # input_specs, - # compilation_settings.device, truncate_double=compilation_settings.truncate_double, ) From 120f30d0bc273eadc7725912ac6ef25951cc32ce Mon Sep 17 00:00:00 2001 From: lanluo-nvidia Date: Mon, 21 Oct 2024 10:04:26 -0700 Subject: [PATCH 11/33] test --- py/torch_tensorrt/dynamo/lowering/_decompositions.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/py/torch_tensorrt/dynamo/lowering/_decompositions.py b/py/torch_tensorrt/dynamo/lowering/_decompositions.py index d195ad81b8..90e7d21149 100644 --- a/py/torch_tensorrt/dynamo/lowering/_decompositions.py +++ b/py/torch_tensorrt/dynamo/lowering/_decompositions.py @@ -3,7 +3,8 @@ from typing import Any, Callable, Dict, List, Optional import torch -from torch._decomp import _decomp_table_to_post_autograd_aten, register_decomposition +from torch._decomp import register_decomposition +from torch._export.utils import 
_decomp_table_to_post_autograd_aten from torch._ops import OpOverload from torch_tensorrt.dynamo._defaults import default_device from torch_tensorrt.dynamo.conversion.converter_utils import get_positive_dim From 424cbf7d16a21b7ff9608a8c92d3407118eab12a Mon Sep 17 00:00:00 2001 From: lanluo-nvidia Date: Mon, 21 Oct 2024 15:39:49 -0700 Subject: [PATCH 12/33] add run_test_with_dynamic_shape change --- tests/py/dynamo/conversion/harness.py | 20 +++++++++++-- tests/py/dynamo/conversion/test_ge_aten.py | 33 +++++++++++++++++++++- 2 files changed, 50 insertions(+), 3 deletions(-) diff --git a/tests/py/dynamo/conversion/harness.py b/tests/py/dynamo/conversion/harness.py index 642140775f..4fc20ddd1a 100644 --- a/tests/py/dynamo/conversion/harness.py +++ b/tests/py/dynamo/conversion/harness.py @@ -3,7 +3,7 @@ import logging import time import unittest -from typing import Callable, List, Optional, Tuple +from typing import Any, Callable, List, Optional, Tuple import torch import torch_tensorrt @@ -12,6 +12,7 @@ from torch_tensorrt import Input from torch_tensorrt._enums import dtype from torch_tensorrt.dynamo import _defaults +from torch_tensorrt.dynamo._defaults import default_device from torch_tensorrt.dynamo._settings import CompilationSettings # Use interpreter, input spec, and test case from fx_ts_compat to test Dynamo Converter Registry @@ -223,10 +224,22 @@ def generate_graph( use_dynamo_tracer: bool, enable_passes: bool, propagate_shapes: bool = False, + torch_export_dynamic_shapes: Optional[Any] = None, ): mod = mod.eval() if use_dynamo_tracer: - exported_program = torch_tensorrt.dynamo.trace(mod, tuple(original_inputs)) + if torch_export_dynamic_shapes is not None: + device = default_device() + torch_export_inputs = get_torch_inputs(original_inputs, device) + exported_program = torch.export.export( + mod, + tuple(torch_export_inputs), + dynamic_shapes=torch_export_dynamic_shapes, + ) + else: + exported_program = torch_tensorrt.dynamo.trace( + mod, tuple(original_inputs) + ) exported_program = pre_export_lowering(exported_program) exported_program = exported_program.run_decompositions( get_decompositions(False) @@ -387,6 +400,8 @@ def run_test_with_dynamic_shape( propagate_shapes=False, check_dtype=True, make_refittable=False, + # this field is optional, in case user wants to specify custom dynamic_shapes rules for the testcase + torch_export_dynamic_shapes: Optional[Any] = None, ): mod = self.generate_graph( mod, @@ -394,6 +409,7 @@ def run_test_with_dynamic_shape( use_dynamo_tracer=use_dynamo_tracer, enable_passes=enable_passes, propagate_shapes=propagate_shapes, + torch_export_dynamic_shapes=torch_export_dynamic_shapes, ) # Previous instance of the interpreter auto-casted 64-bit inputs diff --git a/tests/py/dynamo/conversion/test_ge_aten.py b/tests/py/dynamo/conversion/test_ge_aten.py index a803c1c6b1..3ece1d125d 100644 --- a/tests/py/dynamo/conversion/test_ge_aten.py +++ b/tests/py/dynamo/conversion/test_ge_aten.py @@ -1,6 +1,7 @@ import torch import torch.nn as nn from parameterized import parameterized +from torch.export import Dim from torch.testing._internal.common_utils import run_tests from torch_tensorrt import Input @@ -85,10 +86,40 @@ def forward(self, lhs_val, rhs_val): @parameterized.expand( [ - ("2d_2d", (2, 3), (4, 3), (5, 3), (2, 3), (4, 3), (5, 3)), ("3d_2d", (2, 2, 2), (2, 3, 2), (2, 4, 2), (2, 1), (3, 1), (4, 1)), ] ) + def test_ge_dynamic_tensor_torch_export(self, *args): + class ge(nn.Module): + def forward(self, lhs_val, rhs_val): + return 
torch.ops.aten.ge.Tensor(lhs_val, rhs_val) + + input_specs = [ + Input( + min_shape=args[1], + opt_shape=args[2], + max_shape=args[3], + ), + Input( + min_shape=args[4], + opt_shape=args[5], + max_shape=args[6], + ), + ] + dyn_dim = Dim("dyn_dim", min=2, max=4) + torch_export_dynamic_shapes = {"lhs_val": {1: dyn_dim}, "rhs_val": {0: dyn_dim}} + + self.run_test_with_dynamic_shape( + ge(), + input_specs, + torch_export_dynamic_shapes=torch_export_dynamic_shapes, + ) + + @parameterized.expand( + [ + ("2d_2d", (2, 3), (4, 3), (5, 3), (2, 3), (4, 3), (5, 3)), + ] + ) def test_ge_dynamic_tensor(self, *args): class ge(nn.Module): def forward(self, lhs_val, rhs_val): From ef54cfce4f63fad4e5ecfb0da827b5e25302550c Mon Sep 17 00:00:00 2001 From: lanluo-nvidia Date: Mon, 21 Oct 2024 16:52:16 -0700 Subject: [PATCH 13/33] split the PR, add dummy inference for converter test --- .../dynamo/conversion/_conversion.py | 52 +++++++++++++++++++ tests/py/dynamo/conversion/harness.py | 16 ++++-- tests/py/dynamo/conversion/test_acos_aten.py | 2 +- tests/py/dynamo/conversion/test_acosh_aten.py | 4 +- tests/py/dynamo/conversion/test_any.py | 8 +-- .../py/dynamo/conversion/test_arange_aten.py | 1 - tests/py/dynamo/conversion/test_cat_aten.py | 2 - tests/py/dynamo/conversion/test_full_aten.py | 6 +-- tests/py/dynamo/conversion/test_ge_aten.py | 33 +----------- tests/py/dynamo/conversion/test_isinf_aten.py | 11 +--- tests/py/dynamo/conversion/test_isnan_aten.py | 13 +---- .../conversion/test_logical_and_aten.py | 2 +- 12 files changed, 76 insertions(+), 74 deletions(-) diff --git a/py/torch_tensorrt/dynamo/conversion/_conversion.py b/py/torch_tensorrt/dynamo/conversion/_conversion.py index 589d94fbb2..db1cc4f5b5 100644 --- a/py/torch_tensorrt/dynamo/conversion/_conversion.py +++ b/py/torch_tensorrt/dynamo/conversion/_conversion.py @@ -6,6 +6,8 @@ import tensorrt as trt import torch from torch._subclasses.fake_tensor import FakeTensor +from torch.fx.experimental.proxy_tensor import unset_fake_temporarily +from torch_tensorrt._Device import Device from torch_tensorrt._enums import dtype from torch_tensorrt._features import ENABLED_FEATURES from torch_tensorrt._Input import Input @@ -16,6 +18,7 @@ TRTInterpreterResult, ) from torch_tensorrt.dynamo.runtime import PythonTorchTensorRTModule, TorchTensorRTModule +from torch_tensorrt.dynamo.utils import get_model_device, get_torch_inputs logger = logging.getLogger(__name__) @@ -57,6 +60,55 @@ def infer_module_output_dtypes( return get_output_dtypes(outputs, truncate_double) +def infer_module_output_dtypes_for_test( + module: torch.fx.GraphModule, + inputs: Sequence[Input], + device: Device, + kwarg_inputs: Optional[dict[str, Any]] = None, + truncate_double: bool = False, +) -> List[dtype]: + """ + This function performs model inference to determine the output dtypes + and truncates them accordingly. inputs can be either arg_inputs or flattened input list. + If it is flattened list, kwarg_inputs should be None, as it is already included in the flattened input. + """ + # TODO: We can also determine output dtypes from the module.graph based on node metadata. + # However, our converter tests use fx.symbolic_trace which sometimes does not provide metadata, + # so we stick to the model inference approach currently. + with unset_fake_temporarily(): + # Get the device on which the model exists + # For large models, this can be done on CPU to save GPU memory allocation for TRT. 
+ device = get_model_device(module) + torch_inputs = get_torch_inputs(inputs, device) + if kwarg_inputs is None: + kwarg_inputs = {} + torch_kwarg_inputs = get_torch_inputs(kwarg_inputs, device) + module_outputs = module(*torch_inputs, **torch_kwarg_inputs) + if not isinstance(module_outputs, (list, tuple)): + module_outputs = [module_outputs] + + # Int64 outputs can sometimes be generated from within other operators + # such as aten.sum - such outputs can be truncated + output_dtypes = [] + for output in module_outputs: + output_ = output + # We don't need to check if output is nested here because the input module will be flattened + if not isinstance(output, torch.Tensor): + if isinstance(output, str): + raise ValueError( + f"Received an output type {type(output)} that's not in the acceptable datatypes (https://pytorch.org/docs/stable/tensor_attributes.html#torch.dtype)" + ) + else: + output_ = torch.tensor(output) + + if truncate_double and output_.dtype == dtype.float64: + output_dtypes.append(dtype.float32) + else: + output_dtypes.append(dtype._from(output_.dtype)) + + return output_dtypes + + def interpret_module_to_result( module: torch.fx.GraphModule, inputs: Sequence[Input], diff --git a/tests/py/dynamo/conversion/harness.py b/tests/py/dynamo/conversion/harness.py index 31f74acdce..8920d969eb 100644 --- a/tests/py/dynamo/conversion/harness.py +++ b/tests/py/dynamo/conversion/harness.py @@ -17,7 +17,9 @@ # Use interpreter, input spec, and test case from fx_ts_compat to test Dynamo Converter Registry from torch_tensorrt.dynamo.conversion import TRTInterpreter -from torch_tensorrt.dynamo.conversion._conversion import infer_module_output_dtypes +from torch_tensorrt.dynamo.conversion._conversion import ( + infer_module_output_dtypes_for_test, +) from torch_tensorrt.dynamo.lowering import ( get_decompositions, post_lowering, @@ -273,7 +275,7 @@ def run_test( atol=ATOL, precision=dtype.f32, check_dtype=True, - use_dynamo_tracer=True, + use_dynamo_tracer=False, enable_passes=False, propagate_shapes=False, int32_reqd=False, @@ -326,8 +328,10 @@ def run_test( output_dtypes = None if check_dtype: - output_dtypes = infer_module_output_dtypes( + output_dtypes = infer_module_output_dtypes_for_test( mod, + input_specs, + compilation_settings.device, truncate_double=compilation_settings.truncate_double, ) @@ -399,7 +403,7 @@ def run_test_with_dynamic_shape( rtol=RTOL, atol=ATOL, output_dtypes=None, - use_dynamo_tracer=True, + use_dynamo_tracer=False, enable_passes=False, use_example_tensors=True, pyt_inputs=None, @@ -426,8 +430,10 @@ def run_test_with_dynamic_shape( ) if check_dtype: - output_dtypes = infer_module_output_dtypes( + output_dtypes = infer_module_output_dtypes_for_test( mod, + input_specs, + compilation_settings.device, truncate_double=compilation_settings.truncate_double, ) diff --git a/tests/py/dynamo/conversion/test_acos_aten.py b/tests/py/dynamo/conversion/test_acos_aten.py index 8e93e0a309..81b83bcc4a 100644 --- a/tests/py/dynamo/conversion/test_acos_aten.py +++ b/tests/py/dynamo/conversion/test_acos_aten.py @@ -66,7 +66,7 @@ def forward(self, input): ( "3d_dim_dtype_float", (1, 1, 1), - (2, 2, 3), + (1, 2, 3), (3, 3, 3), torch.float, torch.float, diff --git a/tests/py/dynamo/conversion/test_acosh_aten.py b/tests/py/dynamo/conversion/test_acosh_aten.py index dd19c188c5..090756ddfb 100644 --- a/tests/py/dynamo/conversion/test_acosh_aten.py +++ b/tests/py/dynamo/conversion/test_acosh_aten.py @@ -58,7 +58,7 @@ def forward(self, input): ( "3d_dim_dtype_float", (1, 1, 1), - (2, 2, 
3), + (1, 2, 3), (3, 3, 3), torch.float, torch.float, @@ -66,7 +66,7 @@ def forward(self, input): ( "3d_dim_dtype_int32", (1, 1, 1), - (2, 2, 4), + (1, 2, 4), (2, 3, 5), torch.int32, torch.float, diff --git a/tests/py/dynamo/conversion/test_any.py b/tests/py/dynamo/conversion/test_any.py index 75620f7c34..1d1fc634ef 100644 --- a/tests/py/dynamo/conversion/test_any.py +++ b/tests/py/dynamo/conversion/test_any.py @@ -191,7 +191,7 @@ class TestAnyConverterDynamic(DispatchTestCase): ( "3d_dynamic_float", (2, 1, 1), - (2, 2, 2), + (2, 2, 1), (3, 2, 4), torch.float, ), @@ -234,7 +234,7 @@ def forward(self, x): ( "3d_dynamic_dim_float", (2, 1, 1), - (2, 2, 2), + (2, 2, 1), (3, 2, 4), torch.float, 2, @@ -252,7 +252,7 @@ def forward(self, x): ( "3d_dynamic_dim_bool", (2, 1, 1), - (2, 2, 2), + (2, 2, 1), (3, 2, 4), torch.bool, 0, @@ -285,7 +285,7 @@ def forward(self, x): ( "3d_dynamic_dims_float", (2, 1, 1), - (2, 2, 2), + (2, 2, 1), (3, 2, 4), torch.float, [1, 2], diff --git a/tests/py/dynamo/conversion/test_arange_aten.py b/tests/py/dynamo/conversion/test_arange_aten.py index cb3e4c6b51..1e1e9b3cc7 100644 --- a/tests/py/dynamo/conversion/test_arange_aten.py +++ b/tests/py/dynamo/conversion/test_arange_aten.py @@ -56,7 +56,6 @@ def forward(self, end_tensor): use_example_tensors=False, check_dtype=False, pyt_inputs=[pyt_input], - use_dynamo_tracer=False, ) diff --git a/tests/py/dynamo/conversion/test_cat_aten.py b/tests/py/dynamo/conversion/test_cat_aten.py index abe96af271..a9e4a45c81 100644 --- a/tests/py/dynamo/conversion/test_cat_aten.py +++ b/tests/py/dynamo/conversion/test_cat_aten.py @@ -42,14 +42,12 @@ def forward(self, x, y): min_shape=(16, 2, 3), opt_shape=(16, 3, 3), max_shape=(16, 32, 3), - name="x", ), Input( dtype=torch.float32, min_shape=(16, 2, 3), opt_shape=(16, 16, 3), max_shape=(16, 32, 3), - name="y", ), ] self.run_test_with_dynamic_shape( diff --git a/tests/py/dynamo/conversion/test_full_aten.py b/tests/py/dynamo/conversion/test_full_aten.py index 44d30af43e..29b48d1451 100644 --- a/tests/py/dynamo/conversion/test_full_aten.py +++ b/tests/py/dynamo/conversion/test_full_aten.py @@ -50,11 +50,7 @@ def forward(self, shape): ) ] self.run_test_with_dynamic_shape( - full(), - inputs, - use_example_tensors=False, - check_dtype=False, - use_dynamo_tracer=False, + full(), inputs, use_example_tensors=False, check_dtype=False ) @parameterized.expand( diff --git a/tests/py/dynamo/conversion/test_ge_aten.py b/tests/py/dynamo/conversion/test_ge_aten.py index 3ece1d125d..a803c1c6b1 100644 --- a/tests/py/dynamo/conversion/test_ge_aten.py +++ b/tests/py/dynamo/conversion/test_ge_aten.py @@ -1,7 +1,6 @@ import torch import torch.nn as nn from parameterized import parameterized -from torch.export import Dim from torch.testing._internal.common_utils import run_tests from torch_tensorrt import Input @@ -84,40 +83,10 @@ def forward(self, lhs_val, rhs_val): inputs, ) - @parameterized.expand( - [ - ("3d_2d", (2, 2, 2), (2, 3, 2), (2, 4, 2), (2, 1), (3, 1), (4, 1)), - ] - ) - def test_ge_dynamic_tensor_torch_export(self, *args): - class ge(nn.Module): - def forward(self, lhs_val, rhs_val): - return torch.ops.aten.ge.Tensor(lhs_val, rhs_val) - - input_specs = [ - Input( - min_shape=args[1], - opt_shape=args[2], - max_shape=args[3], - ), - Input( - min_shape=args[4], - opt_shape=args[5], - max_shape=args[6], - ), - ] - dyn_dim = Dim("dyn_dim", min=2, max=4) - torch_export_dynamic_shapes = {"lhs_val": {1: dyn_dim}, "rhs_val": {0: dyn_dim}} - - self.run_test_with_dynamic_shape( - ge(), - input_specs, - 
torch_export_dynamic_shapes=torch_export_dynamic_shapes, - ) - @parameterized.expand( [ ("2d_2d", (2, 3), (4, 3), (5, 3), (2, 3), (4, 3), (5, 3)), + ("3d_2d", (2, 2, 2), (2, 3, 2), (2, 4, 2), (2, 1), (3, 1), (4, 1)), ] ) def test_ge_dynamic_tensor(self, *args): diff --git a/tests/py/dynamo/conversion/test_isinf_aten.py b/tests/py/dynamo/conversion/test_isinf_aten.py index 25aad4381b..d8051c1f41 100644 --- a/tests/py/dynamo/conversion/test_isinf_aten.py +++ b/tests/py/dynamo/conversion/test_isinf_aten.py @@ -50,12 +50,7 @@ def forward(self, input): max_shape=(5, 3, 3), dtype=torch.float32, torch_tensor=torch.tensor( - ( - [ - [[2.7, float("-inf"), 1.1], [4.7, -2.3, float("inf")]], - [[2.7, float("-inf"), 1.1], [4.7, -2.3, float("inf")]], - ] - ), + ([[[2.7, float("-inf"), 1.1], [4.7, -2.3, float("inf")]]]), dtype=torch.float32, ).cuda(), ) @@ -77,9 +72,7 @@ def forward(self, input): opt_shape=(3, 2), max_shape=(5, 3), dtype=torch.int, - torch_tensor=torch.tensor( - ([[-3, 2], [-2, 1], [1, 2]]), dtype=torch.int - ).cuda(), + torch_tensor=torch.tensor(([[-3, 2]]), dtype=torch.int).cuda(), ) ] self.run_test_with_dynamic_shape( diff --git a/tests/py/dynamo/conversion/test_isnan_aten.py b/tests/py/dynamo/conversion/test_isnan_aten.py index ad913f9735..62ba24f319 100644 --- a/tests/py/dynamo/conversion/test_isnan_aten.py +++ b/tests/py/dynamo/conversion/test_isnan_aten.py @@ -52,18 +52,7 @@ def forward(self, input): max_shape=(5, 3, 3), dtype=torch.float32, torch_tensor=torch.tensor( - ( - [ - [ - [3.2, float("nan"), 3.1], - [float("inf"), 1.1, float("nan")], - ], - [ - [3.2, float("nan"), 3.1], - [float("inf"), 1.1, float("nan")], - ], - ] - ), + ([[[3.2, float("nan"), 3.1], [float("inf"), 1.1, float("nan")]]]), dtype=torch.float32, ).cuda(), ) diff --git a/tests/py/dynamo/conversion/test_logical_and_aten.py b/tests/py/dynamo/conversion/test_logical_and_aten.py index 64e8a4c839..9ccd96e81c 100644 --- a/tests/py/dynamo/conversion/test_logical_and_aten.py +++ b/tests/py/dynamo/conversion/test_logical_and_aten.py @@ -37,7 +37,7 @@ def forward(self, lhs_val, rhs_val): ( "3d_dim_dtype_bool", (1, 1, 1), - (2, 2, 3), + (1, 2, 3), (3, 3, 3), torch.bool, ), From 14f5d615ea4845867f8ff721f7ea4cfb15501797 Mon Sep 17 00:00:00 2001 From: lanluo-nvidia Date: Mon, 21 Oct 2024 20:02:56 -0700 Subject: [PATCH 14/33] test --- py/torch_tensorrt/dynamo/_compiler.py | 3 +- .../dynamo/conversion/_conversion.py | 3 + tests/py/dynamo/conversion/harness.py | 17 +- tests/py/dynamo/models/test_models_export.py | 466 +++++++++--------- 4 files changed, 240 insertions(+), 249 deletions(-) diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py index ff1f5b3ea5..95531321e5 100644 --- a/py/torch_tensorrt/dynamo/_compiler.py +++ b/py/torch_tensorrt/dynamo/_compiler.py @@ -456,6 +456,7 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool: outputs = outputs[0].args outputs_meta_val = [] for ele in outputs: + breakpoint() # it can be a torch.fx.node.Node or a tuple of torch.fx.node.Node if isinstance(ele, torch.fx.node.Node): if "val" not in ele.meta: @@ -466,7 +467,7 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool: elif isinstance(ele, tuple): for node in ele: if isinstance(node, torch.fx.node.Node): - if "val" not in ele.meta: + if "val" not in node.meta: raise ValueError( f"{node.name=}: meta['val'] does not exist, expect submodule output node has meta['val'] info" ) diff --git a/py/torch_tensorrt/dynamo/conversion/_conversion.py 
b/py/torch_tensorrt/dynamo/conversion/_conversion.py index db1cc4f5b5..33bbfe8bb8 100644 --- a/py/torch_tensorrt/dynamo/conversion/_conversion.py +++ b/py/torch_tensorrt/dynamo/conversion/_conversion.py @@ -60,6 +60,9 @@ def infer_module_output_dtypes( return get_output_dtypes(outputs, truncate_double) +# this method is only used in our converter test to infer the module output dtypes via dummy inference +# which is due to fx.symbolic_trace does not have the meta['val'] info in the node +# TODO: lan to remove this once our converter test is moved from fx.symbolic_trace to dynamo trace def infer_module_output_dtypes_for_test( module: torch.fx.GraphModule, inputs: Sequence[Input], diff --git a/tests/py/dynamo/conversion/harness.py b/tests/py/dynamo/conversion/harness.py index 8920d969eb..610e424fef 100644 --- a/tests/py/dynamo/conversion/harness.py +++ b/tests/py/dynamo/conversion/harness.py @@ -12,7 +12,6 @@ from torch_tensorrt import Input from torch_tensorrt._enums import dtype from torch_tensorrt.dynamo import _defaults -from torch_tensorrt.dynamo._defaults import default_device from torch_tensorrt.dynamo._settings import CompilationSettings # Use interpreter, input spec, and test case from fx_ts_compat to test Dynamo Converter Registry @@ -227,22 +226,10 @@ def generate_graph( enable_passes: bool, propagate_shapes: bool = False, settings: CompilationSettings = CompilationSettings(), - torch_export_dynamic_shapes: Optional[Any] = None, ): mod = mod.eval() if use_dynamo_tracer: - if torch_export_dynamic_shapes is not None: - device = default_device() - torch_export_inputs = get_torch_inputs(original_inputs, device) - exported_program = torch.export.export( - mod, - tuple(torch_export_inputs), - dynamic_shapes=torch_export_dynamic_shapes, - ) - else: - exported_program = torch_tensorrt.dynamo.trace( - mod, tuple(original_inputs) - ) + exported_program = torch_tensorrt.dynamo.trace(mod, tuple(original_inputs)) exported_program = pre_export_lowering(exported_program, settings) exported_program = exported_program.run_decompositions( get_decompositions(False) @@ -410,7 +397,6 @@ def run_test_with_dynamic_shape( propagate_shapes=False, check_dtype=True, make_refittable=False, - torch_export_dynamic_shapes=None, ): # Previous instance of the interpreter auto-casted 64-bit inputs @@ -426,7 +412,6 @@ def run_test_with_dynamic_shape( enable_passes=enable_passes, propagate_shapes=propagate_shapes, settings=compilation_settings, - torch_export_dynamic_shapes=torch_export_dynamic_shapes, ) if check_dtype: diff --git a/tests/py/dynamo/models/test_models_export.py b/tests/py/dynamo/models/test_models_export.py index 2d0992ca8b..c0a4233e47 100644 --- a/tests/py/dynamo/models/test_models_export.py +++ b/tests/py/dynamo/models/test_models_export.py @@ -5,10 +5,12 @@ from importlib import metadata import pytest -import timm + +# import timm import torch import torch_tensorrt as torchtrt -import torchvision.models as models + +# import torchvision.models as models from torch_tensorrt.dynamo.utils import COSINE_THRESHOLD, cosine_similarity from transformers import BertModel @@ -17,100 +19,100 @@ assertions = unittest.TestCase() -@pytest.mark.unit -def test_resnet18(ir): - model = models.resnet18(pretrained=True).eval().to("cuda") - input = torch.randn((1, 3, 224, 224)).to("cuda") - - compile_spec = { - "inputs": [ - torchtrt.Input( - input.shape, dtype=torch.float, format=torch.contiguous_format - ) - ], - "device": torchtrt.Device("cuda:0"), - "enabled_precisions": {torch.float}, - "ir": ir, - 
"pass_through_build_failures": True, - "optimization_level": 1, - "min_block_size": 8, - "cache_built_engines": False, - "reuse_cached_engines": False, - } - - trt_mod = torchtrt.compile(model, **compile_spec) - cos_sim = cosine_similarity(model(input), trt_mod(input)[0]) - assertions.assertTrue( - cos_sim > COSINE_THRESHOLD, - msg=f"Resnet18 TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", - ) - - # Clean up model env - torch._dynamo.reset() - - -@pytest.mark.unit -def test_mobilenet_v2(ir): - model = models.mobilenet_v2(pretrained=True).eval().to("cuda") - input = torch.randn((1, 3, 224, 224)).to("cuda") - - compile_spec = { - "inputs": [ - torchtrt.Input( - input.shape, dtype=torch.float, format=torch.contiguous_format - ) - ], - "device": torchtrt.Device("cuda:0"), - "enabled_precisions": {torch.float}, - "ir": ir, - "pass_through_build_failures": True, - "optimization_level": 1, - "min_block_size": 8, - "cache_built_engines": False, - "reuse_cached_engines": False, - } - - trt_mod = torchtrt.compile(model, **compile_spec) - cos_sim = cosine_similarity(model(input), trt_mod(input)[0]) - assertions.assertTrue( - cos_sim > COSINE_THRESHOLD, - msg=f"Mobilenet v2 TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", - ) - - # Clean up model env - torch._dynamo.reset() - - -@pytest.mark.unit -def test_efficientnet_b0(ir): - model = timm.create_model("efficientnet_b0", pretrained=True).eval().to("cuda") - input = torch.randn((1, 3, 224, 224)).to("cuda") - - compile_spec = { - "inputs": [ - torchtrt.Input( - input.shape, dtype=torch.float, format=torch.contiguous_format - ) - ], - "device": torchtrt.Device("cuda:0"), - "enabled_precisions": {torch.float}, - "ir": ir, - "pass_through_build_failures": True, - "optimization_level": 1, - "min_block_size": 8, - "cache_built_engines": False, - "reuse_cached_engines": False, - } - - trt_mod = torchtrt.compile(model, **compile_spec) - cos_sim = cosine_similarity(model(input), trt_mod(input)[0]) - assertions.assertTrue( - cos_sim > COSINE_THRESHOLD, - msg=f"EfficientNet-B0 TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", - ) - - # Clean up model env - torch._dynamo.reset() +# @pytest.mark.unit +# def test_resnet18(ir): +# model = models.resnet18(pretrained=True).eval().to("cuda") +# input = torch.randn((1, 3, 224, 224)).to("cuda") + +# compile_spec = { +# "inputs": [ +# torchtrt.Input( +# input.shape, dtype=torch.float, format=torch.contiguous_format +# ) +# ], +# "device": torchtrt.Device("cuda:0"), +# "enabled_precisions": {torch.float}, +# "ir": ir, +# "pass_through_build_failures": True, +# "optimization_level": 1, +# "min_block_size": 8, +# "cache_built_engines": False, +# "reuse_cached_engines": False, +# } + +# trt_mod = torchtrt.compile(model, **compile_spec) +# cos_sim = cosine_similarity(model(input), trt_mod(input)[0]) +# assertions.assertTrue( +# cos_sim > COSINE_THRESHOLD, +# msg=f"Resnet18 TRT outputs don't match with the original model. 
Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", +# ) + +# # Clean up model env +# torch._dynamo.reset() + + +# @pytest.mark.unit +# def test_mobilenet_v2(ir): +# model = models.mobilenet_v2(pretrained=True).eval().to("cuda") +# input = torch.randn((1, 3, 224, 224)).to("cuda") + +# compile_spec = { +# "inputs": [ +# torchtrt.Input( +# input.shape, dtype=torch.float, format=torch.contiguous_format +# ) +# ], +# "device": torchtrt.Device("cuda:0"), +# "enabled_precisions": {torch.float}, +# "ir": ir, +# "pass_through_build_failures": True, +# "optimization_level": 1, +# "min_block_size": 8, +# "cache_built_engines": False, +# "reuse_cached_engines": False, +# } + +# trt_mod = torchtrt.compile(model, **compile_spec) +# cos_sim = cosine_similarity(model(input), trt_mod(input)[0]) +# assertions.assertTrue( +# cos_sim > COSINE_THRESHOLD, +# msg=f"Mobilenet v2 TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", +# ) + +# # Clean up model env +# torch._dynamo.reset() + + +# @pytest.mark.unit +# def test_efficientnet_b0(ir): +# model = timm.create_model("efficientnet_b0", pretrained=True).eval().to("cuda") +# input = torch.randn((1, 3, 224, 224)).to("cuda") + +# compile_spec = { +# "inputs": [ +# torchtrt.Input( +# input.shape, dtype=torch.float, format=torch.contiguous_format +# ) +# ], +# "device": torchtrt.Device("cuda:0"), +# "enabled_precisions": {torch.float}, +# "ir": ir, +# "pass_through_build_failures": True, +# "optimization_level": 1, +# "min_block_size": 8, +# "cache_built_engines": False, +# "reuse_cached_engines": False, +# } + +# trt_mod = torchtrt.compile(model, **compile_spec) +# cos_sim = cosine_similarity(model(input), trt_mod(input)[0]) +# assertions.assertTrue( +# cos_sim > COSINE_THRESHOLD, +# msg=f"EfficientNet-B0 TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", +# ) + +# # Clean up model env +# torch._dynamo.reset() @pytest.mark.unit @@ -162,139 +164,139 @@ def test_bert_base_uncased(ir): torch._dynamo.reset() -@pytest.mark.unit -def test_resnet18_half(ir): - model = models.resnet18(pretrained=True).eval().to("cuda").half() - input = torch.randn((1, 3, 224, 224)).to("cuda").half() - - compile_spec = { - "inputs": [ - torchtrt.Input( - input.shape, dtype=torch.half, format=torch.contiguous_format - ) - ], - "device": torchtrt.Device("cuda:0"), - "enabled_precisions": {torch.half}, - "ir": ir, - "pass_through_build_failures": True, - "optimization_level": 1, - "min_block_size": 8, - "cache_built_engines": False, - "reuse_cached_engines": False, - } - - trt_mod = torchtrt.compile(model, **compile_spec) - cos_sim = cosine_similarity(model(input), trt_mod(input)[0]) - assertions.assertTrue( - cos_sim > COSINE_THRESHOLD, - msg=f"Resnet18 Half TRT outputs don't match with the original model. 
Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", - ) - - # Clean up model env - torch._dynamo.reset() - - -@unittest.skipIf( - torch.cuda.get_device_capability() < (8, 9), - "FP8 quantization requires compute capability 8.9 or later", -) -@unittest.skipIf( - not importlib.util.find_spec("modelopt"), - "ModelOpt is required to run this test", -) -@pytest.mark.unit -def test_base_fp8(ir): - import modelopt.torch.quantization as mtq - from modelopt.torch.quantization.utils import export_torch_mode - from torch.export._trace import _export - - class SimpleNetwork(torch.nn.Module): - def __init__(self): - super(SimpleNetwork, self).__init__() - self.linear1 = torch.nn.Linear(in_features=10, out_features=5) - self.linear2 = torch.nn.Linear(in_features=5, out_features=1) - - def forward(self, x): - x = self.linear1(x) - x = torch.nn.ReLU()(x) - x = self.linear2(x) - return x - - def calibrate_loop(model): - """Simple calibration function for testing.""" - model(input_tensor) - - input_tensor = torch.randn(1, 10).cuda() - model = SimpleNetwork().eval().cuda() - - quant_cfg = mtq.FP8_DEFAULT_CFG - mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop) - # model has FP8 qdq nodes at this point - output_pyt = model(input_tensor) - - with torch.no_grad(): - with export_torch_mode(): - exp_program = _export(model, (input_tensor,)) - trt_model = torchtrt.dynamo.compile( - exp_program, - inputs=[input_tensor], - enabled_precisions={torch.float8_e4m3fn}, - min_block_size=1, - debug=True, - cache_built_engines=False, - reuse_cached_engines=False, - ) - outputs_trt = trt_model(input_tensor) - assert torch.allclose(output_pyt, outputs_trt, rtol=5e-3, atol=1e-2) - - -@unittest.skipIf( - platform.system() != "Linux" - or not importlib.util.find_spec("modelopt") - or Version(metadata.version("nvidia-modelopt")) < Version("0.17.0"), - "modelopt 0.17.0 or later is required, Int8 quantization is supported in modelopt since 0.17.0 or later for linux", -) -@pytest.mark.unit -def test_base_int8(ir): - import modelopt.torch.quantization as mtq - from modelopt.torch.quantization.utils import export_torch_mode - from torch.export._trace import _export - - class SimpleNetwork(torch.nn.Module): - def __init__(self): - super(SimpleNetwork, self).__init__() - self.linear1 = torch.nn.Linear(in_features=10, out_features=5) - self.linear2 = torch.nn.Linear(in_features=5, out_features=1) - - def forward(self, x): - x = self.linear1(x) - x = torch.nn.ReLU()(x) - x = self.linear2(x) - return x - - def calibrate_loop(model): - """Simple calibration function for testing.""" - model(input_tensor) - - input_tensor = torch.randn(1, 10).cuda() - model = SimpleNetwork().eval().cuda() - - quant_cfg = mtq.INT8_DEFAULT_CFG - mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop) - # model has INT8 qdq nodes at this point - output_pyt = model(input_tensor) - - with torch.no_grad(): - with export_torch_mode(): - exp_program = _export(model, (input_tensor,)) - trt_model = torchtrt.dynamo.compile( - exp_program, - inputs=[input_tensor], - enabled_precisions={torch.int8}, - min_block_size=1, - debug=True, - cache_built_engines=False, - reuse_cached_engines=False, - ) - outputs_trt = trt_model(input_tensor) - assert torch.allclose(output_pyt, outputs_trt, rtol=5e-3, atol=1e-2) +# @pytest.mark.unit +# def test_resnet18_half(ir): +# model = models.resnet18(pretrained=True).eval().to("cuda").half() +# input = torch.randn((1, 3, 224, 224)).to("cuda").half() + +# compile_spec = { +# "inputs": [ +# torchtrt.Input( +# input.shape, 
dtype=torch.half, format=torch.contiguous_format +# ) +# ], +# "device": torchtrt.Device("cuda:0"), +# "enabled_precisions": {torch.half}, +# "ir": ir, +# "pass_through_build_failures": True, +# "optimization_level": 1, +# "min_block_size": 8, +# "cache_built_engines": False, +# "reuse_cached_engines": False, +# } + +# trt_mod = torchtrt.compile(model, **compile_spec) +# cos_sim = cosine_similarity(model(input), trt_mod(input)[0]) +# assertions.assertTrue( +# cos_sim > COSINE_THRESHOLD, +# msg=f"Resnet18 Half TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", +# ) + +# # Clean up model env +# torch._dynamo.reset() + + +# @unittest.skipIf( +# torch.cuda.get_device_capability() < (8, 9), +# "FP8 quantization requires compute capability 8.9 or later", +# ) +# @unittest.skipIf( +# not importlib.util.find_spec("modelopt"), +# "ModelOpt is required to run this test", +# ) +# @pytest.mark.unit +# def test_base_fp8(ir): +# import modelopt.torch.quantization as mtq +# from modelopt.torch.quantization.utils import export_torch_mode +# from torch.export._trace import _export + +# class SimpleNetwork(torch.nn.Module): +# def __init__(self): +# super(SimpleNetwork, self).__init__() +# self.linear1 = torch.nn.Linear(in_features=10, out_features=5) +# self.linear2 = torch.nn.Linear(in_features=5, out_features=1) + +# def forward(self, x): +# x = self.linear1(x) +# x = torch.nn.ReLU()(x) +# x = self.linear2(x) +# return x + +# def calibrate_loop(model): +# """Simple calibration function for testing.""" +# model(input_tensor) + +# input_tensor = torch.randn(1, 10).cuda() +# model = SimpleNetwork().eval().cuda() + +# quant_cfg = mtq.FP8_DEFAULT_CFG +# mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop) +# # model has FP8 qdq nodes at this point +# output_pyt = model(input_tensor) + +# with torch.no_grad(): +# with export_torch_mode(): +# exp_program = _export(model, (input_tensor,)) +# trt_model = torchtrt.dynamo.compile( +# exp_program, +# inputs=[input_tensor], +# enabled_precisions={torch.float8_e4m3fn}, +# min_block_size=1, +# debug=True, +# cache_built_engines=False, +# reuse_cached_engines=False, +# ) +# outputs_trt = trt_model(input_tensor) +# assert torch.allclose(output_pyt, outputs_trt, rtol=5e-3, atol=1e-2) + + +# @unittest.skipIf( +# platform.system() != "Linux" +# or not importlib.util.find_spec("modelopt") +# or Version(metadata.version("nvidia-modelopt")) < Version("0.17.0"), +# "modelopt 0.17.0 or later is required, Int8 quantization is supported in modelopt since 0.17.0 or later for linux", +# ) +# @pytest.mark.unit +# def test_base_int8(ir): +# import modelopt.torch.quantization as mtq +# from modelopt.torch.quantization.utils import export_torch_mode +# from torch.export._trace import _export + +# class SimpleNetwork(torch.nn.Module): +# def __init__(self): +# super(SimpleNetwork, self).__init__() +# self.linear1 = torch.nn.Linear(in_features=10, out_features=5) +# self.linear2 = torch.nn.Linear(in_features=5, out_features=1) + +# def forward(self, x): +# x = self.linear1(x) +# x = torch.nn.ReLU()(x) +# x = self.linear2(x) +# return x + +# def calibrate_loop(model): +# """Simple calibration function for testing.""" +# model(input_tensor) + +# input_tensor = torch.randn(1, 10).cuda() +# model = SimpleNetwork().eval().cuda() + +# quant_cfg = mtq.INT8_DEFAULT_CFG +# mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop) +# # model has INT8 qdq nodes at this point +# output_pyt = model(input_tensor) + +# with 
torch.no_grad(): +# with export_torch_mode(): +# exp_program = _export(model, (input_tensor,)) +# trt_model = torchtrt.dynamo.compile( +# exp_program, +# inputs=[input_tensor], +# enabled_precisions={torch.int8}, +# min_block_size=1, +# debug=True, +# cache_built_engines=False, +# reuse_cached_engines=False, +# ) +# outputs_trt = trt_model(input_tensor) +# assert torch.allclose(output_pyt, outputs_trt, rtol=5e-3, atol=1e-2) From 756395940d8a8ad6943c5b16e01c92934f1ef0b4 Mon Sep 17 00:00:00 2001 From: lanluo-nvidia Date: Mon, 21 Oct 2024 20:06:17 -0700 Subject: [PATCH 15/33] test --- tests/py/dynamo/models/test_models_export.py | 466 +++++++++---------- 1 file changed, 232 insertions(+), 234 deletions(-) diff --git a/tests/py/dynamo/models/test_models_export.py b/tests/py/dynamo/models/test_models_export.py index c0a4233e47..2d0992ca8b 100644 --- a/tests/py/dynamo/models/test_models_export.py +++ b/tests/py/dynamo/models/test_models_export.py @@ -5,12 +5,10 @@ from importlib import metadata import pytest - -# import timm +import timm import torch import torch_tensorrt as torchtrt - -# import torchvision.models as models +import torchvision.models as models from torch_tensorrt.dynamo.utils import COSINE_THRESHOLD, cosine_similarity from transformers import BertModel @@ -19,100 +17,100 @@ assertions = unittest.TestCase() -# @pytest.mark.unit -# def test_resnet18(ir): -# model = models.resnet18(pretrained=True).eval().to("cuda") -# input = torch.randn((1, 3, 224, 224)).to("cuda") - -# compile_spec = { -# "inputs": [ -# torchtrt.Input( -# input.shape, dtype=torch.float, format=torch.contiguous_format -# ) -# ], -# "device": torchtrt.Device("cuda:0"), -# "enabled_precisions": {torch.float}, -# "ir": ir, -# "pass_through_build_failures": True, -# "optimization_level": 1, -# "min_block_size": 8, -# "cache_built_engines": False, -# "reuse_cached_engines": False, -# } - -# trt_mod = torchtrt.compile(model, **compile_spec) -# cos_sim = cosine_similarity(model(input), trt_mod(input)[0]) -# assertions.assertTrue( -# cos_sim > COSINE_THRESHOLD, -# msg=f"Resnet18 TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", -# ) - -# # Clean up model env -# torch._dynamo.reset() - - -# @pytest.mark.unit -# def test_mobilenet_v2(ir): -# model = models.mobilenet_v2(pretrained=True).eval().to("cuda") -# input = torch.randn((1, 3, 224, 224)).to("cuda") - -# compile_spec = { -# "inputs": [ -# torchtrt.Input( -# input.shape, dtype=torch.float, format=torch.contiguous_format -# ) -# ], -# "device": torchtrt.Device("cuda:0"), -# "enabled_precisions": {torch.float}, -# "ir": ir, -# "pass_through_build_failures": True, -# "optimization_level": 1, -# "min_block_size": 8, -# "cache_built_engines": False, -# "reuse_cached_engines": False, -# } - -# trt_mod = torchtrt.compile(model, **compile_spec) -# cos_sim = cosine_similarity(model(input), trt_mod(input)[0]) -# assertions.assertTrue( -# cos_sim > COSINE_THRESHOLD, -# msg=f"Mobilenet v2 TRT outputs don't match with the original model. 
Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", -# ) - -# # Clean up model env -# torch._dynamo.reset() - - -# @pytest.mark.unit -# def test_efficientnet_b0(ir): -# model = timm.create_model("efficientnet_b0", pretrained=True).eval().to("cuda") -# input = torch.randn((1, 3, 224, 224)).to("cuda") - -# compile_spec = { -# "inputs": [ -# torchtrt.Input( -# input.shape, dtype=torch.float, format=torch.contiguous_format -# ) -# ], -# "device": torchtrt.Device("cuda:0"), -# "enabled_precisions": {torch.float}, -# "ir": ir, -# "pass_through_build_failures": True, -# "optimization_level": 1, -# "min_block_size": 8, -# "cache_built_engines": False, -# "reuse_cached_engines": False, -# } - -# trt_mod = torchtrt.compile(model, **compile_spec) -# cos_sim = cosine_similarity(model(input), trt_mod(input)[0]) -# assertions.assertTrue( -# cos_sim > COSINE_THRESHOLD, -# msg=f"EfficientNet-B0 TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", -# ) - -# # Clean up model env -# torch._dynamo.reset() +@pytest.mark.unit +def test_resnet18(ir): + model = models.resnet18(pretrained=True).eval().to("cuda") + input = torch.randn((1, 3, 224, 224)).to("cuda") + + compile_spec = { + "inputs": [ + torchtrt.Input( + input.shape, dtype=torch.float, format=torch.contiguous_format + ) + ], + "device": torchtrt.Device("cuda:0"), + "enabled_precisions": {torch.float}, + "ir": ir, + "pass_through_build_failures": True, + "optimization_level": 1, + "min_block_size": 8, + "cache_built_engines": False, + "reuse_cached_engines": False, + } + + trt_mod = torchtrt.compile(model, **compile_spec) + cos_sim = cosine_similarity(model(input), trt_mod(input)[0]) + assertions.assertTrue( + cos_sim > COSINE_THRESHOLD, + msg=f"Resnet18 TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", + ) + + # Clean up model env + torch._dynamo.reset() + + +@pytest.mark.unit +def test_mobilenet_v2(ir): + model = models.mobilenet_v2(pretrained=True).eval().to("cuda") + input = torch.randn((1, 3, 224, 224)).to("cuda") + + compile_spec = { + "inputs": [ + torchtrt.Input( + input.shape, dtype=torch.float, format=torch.contiguous_format + ) + ], + "device": torchtrt.Device("cuda:0"), + "enabled_precisions": {torch.float}, + "ir": ir, + "pass_through_build_failures": True, + "optimization_level": 1, + "min_block_size": 8, + "cache_built_engines": False, + "reuse_cached_engines": False, + } + + trt_mod = torchtrt.compile(model, **compile_spec) + cos_sim = cosine_similarity(model(input), trt_mod(input)[0]) + assertions.assertTrue( + cos_sim > COSINE_THRESHOLD, + msg=f"Mobilenet v2 TRT outputs don't match with the original model. 
Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", + ) + + # Clean up model env + torch._dynamo.reset() + + +@pytest.mark.unit +def test_efficientnet_b0(ir): + model = timm.create_model("efficientnet_b0", pretrained=True).eval().to("cuda") + input = torch.randn((1, 3, 224, 224)).to("cuda") + + compile_spec = { + "inputs": [ + torchtrt.Input( + input.shape, dtype=torch.float, format=torch.contiguous_format + ) + ], + "device": torchtrt.Device("cuda:0"), + "enabled_precisions": {torch.float}, + "ir": ir, + "pass_through_build_failures": True, + "optimization_level": 1, + "min_block_size": 8, + "cache_built_engines": False, + "reuse_cached_engines": False, + } + + trt_mod = torchtrt.compile(model, **compile_spec) + cos_sim = cosine_similarity(model(input), trt_mod(input)[0]) + assertions.assertTrue( + cos_sim > COSINE_THRESHOLD, + msg=f"EfficientNet-B0 TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", + ) + + # Clean up model env + torch._dynamo.reset() @pytest.mark.unit @@ -164,139 +162,139 @@ def test_bert_base_uncased(ir): torch._dynamo.reset() -# @pytest.mark.unit -# def test_resnet18_half(ir): -# model = models.resnet18(pretrained=True).eval().to("cuda").half() -# input = torch.randn((1, 3, 224, 224)).to("cuda").half() - -# compile_spec = { -# "inputs": [ -# torchtrt.Input( -# input.shape, dtype=torch.half, format=torch.contiguous_format -# ) -# ], -# "device": torchtrt.Device("cuda:0"), -# "enabled_precisions": {torch.half}, -# "ir": ir, -# "pass_through_build_failures": True, -# "optimization_level": 1, -# "min_block_size": 8, -# "cache_built_engines": False, -# "reuse_cached_engines": False, -# } - -# trt_mod = torchtrt.compile(model, **compile_spec) -# cos_sim = cosine_similarity(model(input), trt_mod(input)[0]) -# assertions.assertTrue( -# cos_sim > COSINE_THRESHOLD, -# msg=f"Resnet18 Half TRT outputs don't match with the original model. 
Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", -# ) - -# # Clean up model env -# torch._dynamo.reset() - - -# @unittest.skipIf( -# torch.cuda.get_device_capability() < (8, 9), -# "FP8 quantization requires compute capability 8.9 or later", -# ) -# @unittest.skipIf( -# not importlib.util.find_spec("modelopt"), -# "ModelOpt is required to run this test", -# ) -# @pytest.mark.unit -# def test_base_fp8(ir): -# import modelopt.torch.quantization as mtq -# from modelopt.torch.quantization.utils import export_torch_mode -# from torch.export._trace import _export - -# class SimpleNetwork(torch.nn.Module): -# def __init__(self): -# super(SimpleNetwork, self).__init__() -# self.linear1 = torch.nn.Linear(in_features=10, out_features=5) -# self.linear2 = torch.nn.Linear(in_features=5, out_features=1) - -# def forward(self, x): -# x = self.linear1(x) -# x = torch.nn.ReLU()(x) -# x = self.linear2(x) -# return x - -# def calibrate_loop(model): -# """Simple calibration function for testing.""" -# model(input_tensor) - -# input_tensor = torch.randn(1, 10).cuda() -# model = SimpleNetwork().eval().cuda() - -# quant_cfg = mtq.FP8_DEFAULT_CFG -# mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop) -# # model has FP8 qdq nodes at this point -# output_pyt = model(input_tensor) - -# with torch.no_grad(): -# with export_torch_mode(): -# exp_program = _export(model, (input_tensor,)) -# trt_model = torchtrt.dynamo.compile( -# exp_program, -# inputs=[input_tensor], -# enabled_precisions={torch.float8_e4m3fn}, -# min_block_size=1, -# debug=True, -# cache_built_engines=False, -# reuse_cached_engines=False, -# ) -# outputs_trt = trt_model(input_tensor) -# assert torch.allclose(output_pyt, outputs_trt, rtol=5e-3, atol=1e-2) - - -# @unittest.skipIf( -# platform.system() != "Linux" -# or not importlib.util.find_spec("modelopt") -# or Version(metadata.version("nvidia-modelopt")) < Version("0.17.0"), -# "modelopt 0.17.0 or later is required, Int8 quantization is supported in modelopt since 0.17.0 or later for linux", -# ) -# @pytest.mark.unit -# def test_base_int8(ir): -# import modelopt.torch.quantization as mtq -# from modelopt.torch.quantization.utils import export_torch_mode -# from torch.export._trace import _export - -# class SimpleNetwork(torch.nn.Module): -# def __init__(self): -# super(SimpleNetwork, self).__init__() -# self.linear1 = torch.nn.Linear(in_features=10, out_features=5) -# self.linear2 = torch.nn.Linear(in_features=5, out_features=1) - -# def forward(self, x): -# x = self.linear1(x) -# x = torch.nn.ReLU()(x) -# x = self.linear2(x) -# return x - -# def calibrate_loop(model): -# """Simple calibration function for testing.""" -# model(input_tensor) - -# input_tensor = torch.randn(1, 10).cuda() -# model = SimpleNetwork().eval().cuda() - -# quant_cfg = mtq.INT8_DEFAULT_CFG -# mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop) -# # model has INT8 qdq nodes at this point -# output_pyt = model(input_tensor) - -# with torch.no_grad(): -# with export_torch_mode(): -# exp_program = _export(model, (input_tensor,)) -# trt_model = torchtrt.dynamo.compile( -# exp_program, -# inputs=[input_tensor], -# enabled_precisions={torch.int8}, -# min_block_size=1, -# debug=True, -# cache_built_engines=False, -# reuse_cached_engines=False, -# ) -# outputs_trt = trt_model(input_tensor) -# assert torch.allclose(output_pyt, outputs_trt, rtol=5e-3, atol=1e-2) +@pytest.mark.unit +def test_resnet18_half(ir): + model = models.resnet18(pretrained=True).eval().to("cuda").half() + input = torch.randn((1, 3, 224, 
224)).to("cuda").half() + + compile_spec = { + "inputs": [ + torchtrt.Input( + input.shape, dtype=torch.half, format=torch.contiguous_format + ) + ], + "device": torchtrt.Device("cuda:0"), + "enabled_precisions": {torch.half}, + "ir": ir, + "pass_through_build_failures": True, + "optimization_level": 1, + "min_block_size": 8, + "cache_built_engines": False, + "reuse_cached_engines": False, + } + + trt_mod = torchtrt.compile(model, **compile_spec) + cos_sim = cosine_similarity(model(input), trt_mod(input)[0]) + assertions.assertTrue( + cos_sim > COSINE_THRESHOLD, + msg=f"Resnet18 Half TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", + ) + + # Clean up model env + torch._dynamo.reset() + + +@unittest.skipIf( + torch.cuda.get_device_capability() < (8, 9), + "FP8 quantization requires compute capability 8.9 or later", +) +@unittest.skipIf( + not importlib.util.find_spec("modelopt"), + "ModelOpt is required to run this test", +) +@pytest.mark.unit +def test_base_fp8(ir): + import modelopt.torch.quantization as mtq + from modelopt.torch.quantization.utils import export_torch_mode + from torch.export._trace import _export + + class SimpleNetwork(torch.nn.Module): + def __init__(self): + super(SimpleNetwork, self).__init__() + self.linear1 = torch.nn.Linear(in_features=10, out_features=5) + self.linear2 = torch.nn.Linear(in_features=5, out_features=1) + + def forward(self, x): + x = self.linear1(x) + x = torch.nn.ReLU()(x) + x = self.linear2(x) + return x + + def calibrate_loop(model): + """Simple calibration function for testing.""" + model(input_tensor) + + input_tensor = torch.randn(1, 10).cuda() + model = SimpleNetwork().eval().cuda() + + quant_cfg = mtq.FP8_DEFAULT_CFG + mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop) + # model has FP8 qdq nodes at this point + output_pyt = model(input_tensor) + + with torch.no_grad(): + with export_torch_mode(): + exp_program = _export(model, (input_tensor,)) + trt_model = torchtrt.dynamo.compile( + exp_program, + inputs=[input_tensor], + enabled_precisions={torch.float8_e4m3fn}, + min_block_size=1, + debug=True, + cache_built_engines=False, + reuse_cached_engines=False, + ) + outputs_trt = trt_model(input_tensor) + assert torch.allclose(output_pyt, outputs_trt, rtol=5e-3, atol=1e-2) + + +@unittest.skipIf( + platform.system() != "Linux" + or not importlib.util.find_spec("modelopt") + or Version(metadata.version("nvidia-modelopt")) < Version("0.17.0"), + "modelopt 0.17.0 or later is required, Int8 quantization is supported in modelopt since 0.17.0 or later for linux", +) +@pytest.mark.unit +def test_base_int8(ir): + import modelopt.torch.quantization as mtq + from modelopt.torch.quantization.utils import export_torch_mode + from torch.export._trace import _export + + class SimpleNetwork(torch.nn.Module): + def __init__(self): + super(SimpleNetwork, self).__init__() + self.linear1 = torch.nn.Linear(in_features=10, out_features=5) + self.linear2 = torch.nn.Linear(in_features=5, out_features=1) + + def forward(self, x): + x = self.linear1(x) + x = torch.nn.ReLU()(x) + x = self.linear2(x) + return x + + def calibrate_loop(model): + """Simple calibration function for testing.""" + model(input_tensor) + + input_tensor = torch.randn(1, 10).cuda() + model = SimpleNetwork().eval().cuda() + + quant_cfg = mtq.INT8_DEFAULT_CFG + mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop) + # model has INT8 qdq nodes at this point + output_pyt = model(input_tensor) + + with torch.no_grad(): + with 
export_torch_mode(): + exp_program = _export(model, (input_tensor,)) + trt_model = torchtrt.dynamo.compile( + exp_program, + inputs=[input_tensor], + enabled_precisions={torch.int8}, + min_block_size=1, + debug=True, + cache_built_engines=False, + reuse_cached_engines=False, + ) + outputs_trt = trt_model(input_tensor) + assert torch.allclose(output_pyt, outputs_trt, rtol=5e-3, atol=1e-2) From 77355f0a6ed5eee8bc5dbca8be4568612a29b963 Mon Sep 17 00:00:00 2001 From: lanluo-nvidia Date: Mon, 21 Oct 2024 20:10:17 -0700 Subject: [PATCH 16/33] test --- py/torch_tensorrt/dynamo/_compiler.py | 1 - 1 file changed, 1 deletion(-) diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py index 95531321e5..cac615d6ae 100644 --- a/py/torch_tensorrt/dynamo/_compiler.py +++ b/py/torch_tensorrt/dynamo/_compiler.py @@ -456,7 +456,6 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool: outputs = outputs[0].args outputs_meta_val = [] for ele in outputs: - breakpoint() # it can be a torch.fx.node.Node or a tuple of torch.fx.node.Node if isinstance(ele, torch.fx.node.Node): if "val" not in ele.meta: From 891e9638207b27f24b384958ba8f694ec1f50efb Mon Sep 17 00:00:00 2001 From: lanluo-nvidia Date: Mon, 21 Oct 2024 22:09:45 -0700 Subject: [PATCH 17/33] enable converter non dynamic shape tests to use dynamo tracer enable converter dynamic shape tests to use dynamoc tracer batch by batch current batch: test_a*.py --- tests/py/dynamo/conversion/harness.py | 69 +++++++++++++++---- tests/py/dynamo/conversion/test_acos_aten.py | 2 +- tests/py/dynamo/conversion/test_acosh_aten.py | 4 +- tests/py/dynamo/conversion/test_any.py | 8 +-- .../py/dynamo/conversion/test_arange_aten.py | 1 + tests/py/dynamo/conversion/test_asin_aten.py | 2 +- tests/py/dynamo/conversion/test_asinh_aten.py | 10 +-- tests/py/dynamo/conversion/test_atan2_aten.py | 6 +- tests/py/dynamo/conversion/test_atan_aten.py | 4 +- 9 files changed, 74 insertions(+), 32 deletions(-) diff --git a/tests/py/dynamo/conversion/harness.py b/tests/py/dynamo/conversion/harness.py index 610e424fef..f6251f3913 100644 --- a/tests/py/dynamo/conversion/harness.py +++ b/tests/py/dynamo/conversion/harness.py @@ -12,11 +12,14 @@ from torch_tensorrt import Input from torch_tensorrt._enums import dtype from torch_tensorrt.dynamo import _defaults +from torch_tensorrt.dynamo._defaults import default_device from torch_tensorrt.dynamo._settings import CompilationSettings +from torch_tensorrt.dynamo._tracer import get_dynamic_shapes_args # Use interpreter, input spec, and test case from fx_ts_compat to test Dynamo Converter Registry from torch_tensorrt.dynamo.conversion import TRTInterpreter from torch_tensorrt.dynamo.conversion._conversion import ( + infer_module_output_dtypes, infer_module_output_dtypes_for_test, ) from torch_tensorrt.dynamo.lowering import ( @@ -30,6 +33,26 @@ _LOGGER: logging.Logger = logging.getLogger(__name__) +# this is to enable dynamo tracer as Truein the converter test files batch by batch +def get_use_dynamo_tracer(use_dynamo_tracer: Any) -> bool: + # if in our converter tests we specifically set use_dynamo_tracer field, honor it + if use_dynamo_tracer is not None and isinstance(use_dynamo_tracer, bool): + return use_dynamo_tracer + + # if in our converter tests, we did not specify use_dynamo_tracer field + import inspect + import os + import re + + filename = os.path.basename(inspect.stack()[2].filename) + # enable converter test files which starts with test_a*.py to use dynamo tracer + pattern = 
re.compile("^test_([a])+") + if pattern.match(filename): + return True + else: + return False + + def fetch_attr(mod, target): """ Fetch an attribute from the ``Module`` hierarchy of ``mod.module``. @@ -226,10 +249,21 @@ def generate_graph( enable_passes: bool, propagate_shapes: bool = False, settings: CompilationSettings = CompilationSettings(), + torch_export_dynamic_shapes: Optional[Any] = None, ): mod = mod.eval() if use_dynamo_tracer: - exported_program = torch_tensorrt.dynamo.trace(mod, tuple(original_inputs)) + if torch_export_dynamic_shapes is None: + torch_export_dynamic_shapes = get_dynamic_shapes_args( + mod, original_inputs + ) + device = default_device() + torch_export_inputs = get_torch_inputs(original_inputs, device) + exported_program = torch.export.export( + mod, + tuple(torch_export_inputs), + dynamic_shapes=torch_export_dynamic_shapes, + ) exported_program = pre_export_lowering(exported_program, settings) exported_program = exported_program.run_decompositions( get_decompositions(False) @@ -262,7 +296,6 @@ def run_test( atol=ATOL, precision=dtype.f32, check_dtype=True, - use_dynamo_tracer=False, enable_passes=False, propagate_shapes=False, int32_reqd=False, @@ -281,7 +314,7 @@ def run_test( mod = self.generate_graph( mod, inputs, - use_dynamo_tracer=use_dynamo_tracer, + use_dynamo_tracer=True, enable_passes=enable_passes, propagate_shapes=propagate_shapes, settings=compilation_settings, @@ -315,10 +348,8 @@ def run_test( output_dtypes = None if check_dtype: - output_dtypes = infer_module_output_dtypes_for_test( + output_dtypes = infer_module_output_dtypes( mod, - input_specs, - compilation_settings.device, truncate_double=compilation_settings.truncate_double, ) @@ -390,21 +421,24 @@ def run_test_with_dynamic_shape( rtol=RTOL, atol=ATOL, output_dtypes=None, - use_dynamo_tracer=False, + use_dynamo_tracer=None, enable_passes=False, use_example_tensors=True, pyt_inputs=None, propagate_shapes=False, check_dtype=True, make_refittable=False, + torch_export_dynamic_shapes=None, ): + # TODO: lan to remove this and set use_dynamo_traccer to True by default + # once all the converter test files are moved to use_dynamo_tracer + use_dynamo_tracer = get_use_dynamo_tracer(use_dynamo_tracer) # Previous instance of the interpreter auto-casted 64-bit inputs # We replicate this behavior here compilation_settings = CompilationSettings( truncate_double=True, make_refittable=make_refittable ) - mod = self.generate_graph( mod, input_specs, @@ -412,15 +446,22 @@ def run_test_with_dynamic_shape( enable_passes=enable_passes, propagate_shapes=propagate_shapes, settings=compilation_settings, + torch_export_dynamic_shapes=torch_export_dynamic_shapes, ) if check_dtype: - output_dtypes = infer_module_output_dtypes_for_test( - mod, - input_specs, - compilation_settings.device, - truncate_double=compilation_settings.truncate_double, - ) + if use_dynamo_tracer: + output_dtypes = infer_module_output_dtypes( + mod, + truncate_double=compilation_settings.truncate_double, + ) + else: + output_dtypes = infer_module_output_dtypes_for_test( + mod, + input_specs, + compilation_settings.device, + truncate_double=compilation_settings.truncate_double, + ) interp = TRTInterpreter( mod, diff --git a/tests/py/dynamo/conversion/test_acos_aten.py b/tests/py/dynamo/conversion/test_acos_aten.py index 81b83bcc4a..8e93e0a309 100644 --- a/tests/py/dynamo/conversion/test_acos_aten.py +++ b/tests/py/dynamo/conversion/test_acos_aten.py @@ -66,7 +66,7 @@ def forward(self, input): ( "3d_dim_dtype_float", (1, 1, 1), - (1, 2, 3), + 
(2, 2, 3), (3, 3, 3), torch.float, torch.float, diff --git a/tests/py/dynamo/conversion/test_acosh_aten.py b/tests/py/dynamo/conversion/test_acosh_aten.py index 090756ddfb..dd19c188c5 100644 --- a/tests/py/dynamo/conversion/test_acosh_aten.py +++ b/tests/py/dynamo/conversion/test_acosh_aten.py @@ -58,7 +58,7 @@ def forward(self, input): ( "3d_dim_dtype_float", (1, 1, 1), - (1, 2, 3), + (2, 2, 3), (3, 3, 3), torch.float, torch.float, @@ -66,7 +66,7 @@ def forward(self, input): ( "3d_dim_dtype_int32", (1, 1, 1), - (1, 2, 4), + (2, 2, 4), (2, 3, 5), torch.int32, torch.float, diff --git a/tests/py/dynamo/conversion/test_any.py b/tests/py/dynamo/conversion/test_any.py index 1d1fc634ef..75620f7c34 100644 --- a/tests/py/dynamo/conversion/test_any.py +++ b/tests/py/dynamo/conversion/test_any.py @@ -191,7 +191,7 @@ class TestAnyConverterDynamic(DispatchTestCase): ( "3d_dynamic_float", (2, 1, 1), - (2, 2, 1), + (2, 2, 2), (3, 2, 4), torch.float, ), @@ -234,7 +234,7 @@ def forward(self, x): ( "3d_dynamic_dim_float", (2, 1, 1), - (2, 2, 1), + (2, 2, 2), (3, 2, 4), torch.float, 2, @@ -252,7 +252,7 @@ def forward(self, x): ( "3d_dynamic_dim_bool", (2, 1, 1), - (2, 2, 1), + (2, 2, 2), (3, 2, 4), torch.bool, 0, @@ -285,7 +285,7 @@ def forward(self, x): ( "3d_dynamic_dims_float", (2, 1, 1), - (2, 2, 1), + (2, 2, 2), (3, 2, 4), torch.float, [1, 2], diff --git a/tests/py/dynamo/conversion/test_arange_aten.py b/tests/py/dynamo/conversion/test_arange_aten.py index 1e1e9b3cc7..cb3e4c6b51 100644 --- a/tests/py/dynamo/conversion/test_arange_aten.py +++ b/tests/py/dynamo/conversion/test_arange_aten.py @@ -56,6 +56,7 @@ def forward(self, end_tensor): use_example_tensors=False, check_dtype=False, pyt_inputs=[pyt_input], + use_dynamo_tracer=False, ) diff --git a/tests/py/dynamo/conversion/test_asin_aten.py b/tests/py/dynamo/conversion/test_asin_aten.py index 2b8eb84144..f520e5c7a3 100644 --- a/tests/py/dynamo/conversion/test_asin_aten.py +++ b/tests/py/dynamo/conversion/test_asin_aten.py @@ -66,7 +66,7 @@ def forward(self, input): ( "3d_dim_dtype_float", (1, 1, 1), - (1, 2, 3), + (2, 2, 3), (3, 3, 3), torch.float, torch.float, diff --git a/tests/py/dynamo/conversion/test_asinh_aten.py b/tests/py/dynamo/conversion/test_asinh_aten.py index c5fdc71883..025cdbefb5 100644 --- a/tests/py/dynamo/conversion/test_asinh_aten.py +++ b/tests/py/dynamo/conversion/test_asinh_aten.py @@ -11,9 +11,9 @@ class TestAsinhConverter(DispatchTestCase): @parameterized.expand( [ ((10,), torch.float), - ((1, 20), torch.float), - ((2, 3, 4), torch.float), - ((2, 3, 4, 5), torch.float), + # ((1, 20), torch.float), + # ((2, 3, 4), torch.float), + # ((2, 3, 4, 5), torch.float), ] ) def test_asinh_float(self, input_shape, dtype): @@ -58,7 +58,7 @@ def forward(self, input): ( "3d_dim_dtype_float", (1, 1, 1), - (1, 2, 3), + (2, 2, 3), (3, 3, 3), torch.float, torch.float, @@ -66,7 +66,7 @@ def forward(self, input): ( "3d_dim_dtype_int32", (1, 1, 1), - (1, 2, 4), + (2, 2, 4), (2, 3, 5), torch.int32, torch.float, diff --git a/tests/py/dynamo/conversion/test_atan2_aten.py b/tests/py/dynamo/conversion/test_atan2_aten.py index b684a960ac..a35eb5409a 100644 --- a/tests/py/dynamo/conversion/test_atan2_aten.py +++ b/tests/py/dynamo/conversion/test_atan2_aten.py @@ -141,7 +141,7 @@ def forward(self, lhs_val, rhs_val): ( "3d_dim_dtype_float", (1, 1, 1), - (1, 2, 3), + (2, 2, 3), (3, 3, 3), torch.float, torch.float, @@ -149,7 +149,7 @@ def forward(self, lhs_val, rhs_val): ( "3d_dim_dtype_int32", (1, 1, 1), - (1, 2, 4), + (2, 2, 4), (2, 3, 5), torch.int32, 
torch.float, @@ -220,7 +220,7 @@ def forward(self, lhs_val, rhs_val, out): ( "3d_dim_dtype_float", (1, 1, 1), - (1, 2, 3), + (2, 2, 3), (3, 3, 3), torch.float, torch.float, diff --git a/tests/py/dynamo/conversion/test_atan_aten.py b/tests/py/dynamo/conversion/test_atan_aten.py index dda06404cb..2df58d4f69 100644 --- a/tests/py/dynamo/conversion/test_atan_aten.py +++ b/tests/py/dynamo/conversion/test_atan_aten.py @@ -58,7 +58,7 @@ def forward(self, input): ( "3d_dim_dtype_float", (1, 1, 1), - (1, 2, 3), + (2, 2, 3), (3, 3, 3), torch.float, torch.float, @@ -66,7 +66,7 @@ def forward(self, input): ( "3d_dim_dtype_int32", (1, 1, 1), - (1, 2, 4), + (2, 2, 4), (2, 3, 5), torch.int32, torch.float, From 13361fda685301db70fcbc5b638b4d826cbe8dd8 Mon Sep 17 00:00:00 2001 From: lanluo-nvidia Date: Tue, 22 Oct 2024 13:43:44 -0700 Subject: [PATCH 18/33] add linear lowering meta val --- py/torch_tensorrt/dynamo/_compiler.py | 27 +-------- .../dynamo/conversion/_conversion.py | 31 ++-------- .../dynamo/lowering/passes/lower_linear.py | 10 ++++ py/torch_tensorrt/dynamo/utils.py | 57 +++++++++++++++++++ tests/py/dynamo/models/test_models.py | 5 -- 5 files changed, 74 insertions(+), 56 deletions(-) diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py index cac615d6ae..0fe27868d0 100644 --- a/py/torch_tensorrt/dynamo/_compiler.py +++ b/py/torch_tensorrt/dynamo/_compiler.py @@ -36,6 +36,7 @@ ) from torch_tensorrt.dynamo.utils import ( get_flat_args_with_check, + get_output_meta_val, parse_graph_io, prepare_inputs, set_log_level, @@ -454,31 +455,7 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool: # set the submodule meta val back to the parent trt_module_node outputs = [node for node in submodule.graph.nodes if node.op == "output"] outputs = outputs[0].args - outputs_meta_val = [] - for ele in outputs: - # it can be a torch.fx.node.Node or a tuple of torch.fx.node.Node - if isinstance(ele, torch.fx.node.Node): - if "val" not in ele.meta: - raise ValueError( - f"node.name={ele.name}: meta['val'] does not exist, expect submodule output node has meta['val'] info" - ) - outputs_meta_val.append(ele.meta["val"]) - elif isinstance(ele, tuple): - for node in ele: - if isinstance(node, torch.fx.node.Node): - if "val" not in node.meta: - raise ValueError( - f"{node.name=}: meta['val'] does not exist, expect submodule output node has meta['val'] info" - ) - outputs_meta_val.append(node.meta["val"]) - else: - raise ValueError( - f"expect torch.fx.node.Node type, got not expected types: {type(node)=}" - ) - else: - raise ValueError( - f"expect torch.fx.node.Node or tuple of torch.fx.node.Node type, got not expected types: {type(ele)=}" - ) + outputs_meta_val = get_output_meta_val(outputs) if name not in submodule_node_dict: raise ValueError( diff --git a/py/torch_tensorrt/dynamo/conversion/_conversion.py b/py/torch_tensorrt/dynamo/conversion/_conversion.py index 33bbfe8bb8..ed2e990823 100644 --- a/py/torch_tensorrt/dynamo/conversion/_conversion.py +++ b/py/torch_tensorrt/dynamo/conversion/_conversion.py @@ -5,7 +5,6 @@ import tensorrt as trt import torch -from torch._subclasses.fake_tensor import FakeTensor from torch.fx.experimental.proxy_tensor import unset_fake_temporarily from torch_tensorrt._Device import Device from torch_tensorrt._enums import dtype @@ -18,35 +17,15 @@ TRTInterpreterResult, ) from torch_tensorrt.dynamo.runtime import PythonTorchTensorRTModule, TorchTensorRTModule -from torch_tensorrt.dynamo.utils import get_model_device, get_torch_inputs +from 
torch_tensorrt.dynamo.utils import ( + get_model_device, + get_output_dtypes, + get_torch_inputs, +) logger = logging.getLogger(__name__) -def get_output_dtypes(output: Any, truncate_doulbe: bool = False) -> List[dtype]: - output_dtypes = [] - if isinstance(output, torch.fx.node.Node): - if "val" in output.meta: - output_meta = output.meta["val"] - if isinstance(output_meta, (FakeTensor, torch.Tensor)): - if truncate_doulbe and output_meta.dtype == torch.float64: - output_dtypes.append(dtype.float32) - else: - output_dtypes.append(dtype._from(output_meta.dtype)) - else: - raise ValueError( - f"node.name={output.name}: node.meta['val'] does not exist, expect node.meta['val'] exists for each output node" - ) - elif isinstance(output, tuple): - for ele in output: - output_dtypes.extend(get_output_dtypes(ele)) - else: - raise ValueError( - f"got unexpected type {type(output)}, expected type is a torch.fx.node.Node or a tuple of torch.fx.node.Node" - ) - return output_dtypes - - def infer_module_output_dtypes( module: torch.fx.GraphModule, truncate_double: bool = False, diff --git a/py/torch_tensorrt/dynamo/lowering/passes/lower_linear.py b/py/torch_tensorrt/dynamo/lowering/passes/lower_linear.py index 9bd7ed8422..f1a884607e 100644 --- a/py/torch_tensorrt/dynamo/lowering/passes/lower_linear.py +++ b/py/torch_tensorrt/dynamo/lowering/passes/lower_linear.py @@ -6,6 +6,7 @@ from torch_tensorrt.dynamo.lowering.passes.pass_utils import ( clean_up_graph_after_modifications, ) +from torch_tensorrt.dynamo.utils import get_output_meta_val, set_output_meta_val logger = logging.getLogger(__name__) @@ -14,12 +15,21 @@ def lower_linear( gm: torch.fx.GraphModule, settings: CompilationSettings ) -> torch.fx.GraphModule: """Replace aten.linear with an equivalent implementation which can be easily converted to TRT""" + outputs = [node for node in gm.graph.nodes if node.op == "output"] + outputs = outputs[0].args + outputs_meta_val = get_output_meta_val(outputs) + orig, replacement = linear_replacement() if torch.fx.subgraph_rewriter.replace_pattern(gm, orig, replacement): gm = clean_up_graph_after_modifications(gm) logger.debug(f"Graph after lowering linear:\n{gm.graph}") + outputs = [node for node in gm.graph.nodes if node.op == "output"] + outputs = outputs[0].args + output_num = len(outputs_meta_val) + assert output_num > 0 + set_output_meta_val(outputs, outputs_meta_val) return gm diff --git a/py/torch_tensorrt/dynamo/utils.py b/py/torch_tensorrt/dynamo/utils.py index a85494239e..2c580244e4 100644 --- a/py/torch_tensorrt/dynamo/utils.py +++ b/py/torch_tensorrt/dynamo/utils.py @@ -661,3 +661,60 @@ def get_flat_args_with_check( flat_args_with_path, received_spec = pytree.tree_flatten_with_path((args, kwargs)) flat_args = tuple(x[1] for x in flat_args_with_path) return flat_args, received_spec + + +def get_output_meta_val(output: Any) -> List[Any]: + output_meta_val = [] + if isinstance(output, torch.fx.node.Node): + if "val" not in output.meta: + raise ValueError( + f"node.name={output.name}: meta['val'] does not exist, expect output node has meta['val'] info" + ) + output_meta_val.append(output.meta["val"]) + elif isinstance(output, tuple): + for node in output: + output_meta_val.extend(get_output_meta_val(node)) + else: + raise ValueError( + f"expect torch.fx.node.Node or a tuple of torch.fx.node.Node type, got unexpected types: {type(output)=}" + ) + return output_meta_val + + +def set_output_meta_val(output: Any, outputs_meta_val: List[Any]) -> None: + if isinstance(output, torch.fx.node.Node): + assert 
len(outputs_meta_val) > 0 + if "val" not in output.meta: + output.meta["val"] = outputs_meta_val[0] + outputs_meta_val.pop(0) + elif isinstance(output, tuple): + for node in output: + set_output_meta_val(node, outputs_meta_val) + else: + raise ValueError( + f"expect torch.fx.node.Node or a tuple of torch.fx.node.Node type, got unexpected types: {type(output)=}" + ) + + +def get_output_dtypes(output: Any, truncate_doulbe: bool = False) -> List[dtype]: + output_dtypes = [] + if isinstance(output, torch.fx.node.Node): + if "val" in output.meta: + output_meta = output.meta["val"] + if isinstance(output_meta, (FakeTensor, torch.Tensor)): + if truncate_doulbe and output_meta.dtype == torch.float64: + output_dtypes.append(dtype.float32) + else: + output_dtypes.append(dtype._from(output_meta.dtype)) + else: + raise ValueError( + f"node.name={output.name}: node.meta['val'] does not exist, expect node.meta['val'] exists for each output node" + ) + elif isinstance(output, tuple): + for ele in output: + output_dtypes.extend(get_output_dtypes(ele)) + else: + raise ValueError( + f"got unexpected type {type(output)}, expected type is a torch.fx.node.Node or a tuple of torch.fx.node.Node" + ) + return output_dtypes diff --git a/tests/py/dynamo/models/test_models.py b/tests/py/dynamo/models/test_models.py index ba6cb0c776..b6f986711a 100644 --- a/tests/py/dynamo/models/test_models.py +++ b/tests/py/dynamo/models/test_models.py @@ -29,7 +29,6 @@ def test_resnet18(ir): "ir": ir, "pass_through_build_failures": True, "optimization_level": 1, - "ir": "torch_compile", "cache_built_engines": False, "reuse_cached_engines": False, } @@ -62,7 +61,6 @@ def test_mobilenet_v2(ir): "pass_through_build_failures": True, "optimization_level": 1, "min_block_size": 10, - "ir": "torch_compile", "cache_built_engines": False, "reuse_cached_engines": False, } @@ -95,7 +93,6 @@ def test_efficientnet_b0(ir): "pass_through_build_failures": True, "optimization_level": 1, "min_block_size": 10, - "ir": "torch_compile", "cache_built_engines": False, "reuse_cached_engines": False, } @@ -137,7 +134,6 @@ def test_bert_base_uncased(ir): "pass_through_build_failures": True, "optimization_level": 1, "min_block_size": 15, - "ir": "torch_compile", "cache_built_engines": False, "reuse_cached_engines": False, } @@ -173,7 +169,6 @@ def test_resnet18_half(ir): "ir": ir, "pass_through_build_failures": True, "optimization_level": 1, - "ir": "torch_compile", "cache_built_engines": False, "reuse_cached_engines": False, } From f0a9fefefd00bfabb0c121cc9535e618dad22acd Mon Sep 17 00:00:00 2001 From: lanluo-nvidia Date: Wed, 23 Oct 2024 11:31:32 -0700 Subject: [PATCH 19/33] add linear_lowering change --- py/torch_tensorrt/dynamo/_compiler.py | 15 ++++++++------- .../dynamo/lowering/passes/lower_linear.py | 14 ++++++++------ py/torch_tensorrt/dynamo/utils.py | 9 +++------ tests/py/dynamo/conversion/harness.py | 2 +- 4 files changed, 20 insertions(+), 20 deletions(-) diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py index 0fe27868d0..a792ce4329 100644 --- a/py/torch_tensorrt/dynamo/_compiler.py +++ b/py/torch_tensorrt/dynamo/_compiler.py @@ -296,7 +296,6 @@ def compile( settings = CompilationSettings(**compilation_options) logger.info("Compilation Settings: %s\n", settings) - exported_program = pre_export_lowering(exported_program, settings) exported_program = exported_program.run_decompositions( get_decompositions(enable_experimental_decompositions) @@ -452,16 +451,18 @@ def contains_metadata(gm: 
torch.fx.GraphModule) -> bool: ) continue - # set the submodule meta val back to the parent trt_module_node - outputs = [node for node in submodule.graph.nodes if node.op == "output"] - outputs = outputs[0].args - outputs_meta_val = get_output_meta_val(outputs) - if name not in submodule_node_dict: raise ValueError( f"node_name: {name} does not exist in the submodule node dictionary" ) - submodule_node_dict[name].meta["val"] = outputs_meta_val + + # set the submodule meta val back to the parent trt_module_node + if "val" not in submodule_node_dict[name].meta: + outputs = [node for node in submodule.graph.nodes if node.op == "output"] + outputs = outputs[0].args + outputs_meta_val = get_output_meta_val(outputs) + assert len(outputs_meta_val) > 0 + submodule_node_dict[name].meta["val"] = outputs_meta_val subgraph_data = PerSubgraphData() subgraph_data.subgraph_name = name diff --git a/py/torch_tensorrt/dynamo/lowering/passes/lower_linear.py b/py/torch_tensorrt/dynamo/lowering/passes/lower_linear.py index f1a884607e..54e63a496f 100644 --- a/py/torch_tensorrt/dynamo/lowering/passes/lower_linear.py +++ b/py/torch_tensorrt/dynamo/lowering/passes/lower_linear.py @@ -15,21 +15,23 @@ def lower_linear( gm: torch.fx.GraphModule, settings: CompilationSettings ) -> torch.fx.GraphModule: """Replace aten.linear with an equivalent implementation which can be easily converted to TRT""" + outputs = [node for node in gm.graph.nodes if node.op == "output"] outputs = outputs[0].args outputs_meta_val = get_output_meta_val(outputs) orig, replacement = linear_replacement() + replaced_nodes = torch.fx.subgraph_rewriter.replace_pattern(gm, orig, replacement) - if torch.fx.subgraph_rewriter.replace_pattern(gm, orig, replacement): + if len(replaced_nodes) > 0: gm = clean_up_graph_after_modifications(gm) logger.debug(f"Graph after lowering linear:\n{gm.graph}") - outputs = [node for node in gm.graph.nodes if node.op == "output"] - outputs = outputs[0].args - output_num = len(outputs_meta_val) - assert output_num > 0 - set_output_meta_val(outputs, outputs_meta_val) + outputs = [node for node in gm.graph.nodes if node.op == "output"] + outputs = outputs[0].args + output_num = len(outputs_meta_val) + assert output_num > 0 + set_output_meta_val(outputs, outputs_meta_val) return gm diff --git a/py/torch_tensorrt/dynamo/utils.py b/py/torch_tensorrt/dynamo/utils.py index 2c580244e4..b8e1489fc4 100644 --- a/py/torch_tensorrt/dynamo/utils.py +++ b/py/torch_tensorrt/dynamo/utils.py @@ -666,11 +666,8 @@ def get_flat_args_with_check( def get_output_meta_val(output: Any) -> List[Any]: output_meta_val = [] if isinstance(output, torch.fx.node.Node): - if "val" not in output.meta: - raise ValueError( - f"node.name={output.name}: meta['val'] does not exist, expect output node has meta['val'] info" - ) - output_meta_val.append(output.meta["val"]) + if "val" in output.meta: + output_meta_val.append(output.meta["val"]) elif isinstance(output, tuple): for node in output: output_meta_val.extend(get_output_meta_val(node)) @@ -686,7 +683,7 @@ def set_output_meta_val(output: Any, outputs_meta_val: List[Any]) -> None: assert len(outputs_meta_val) > 0 if "val" not in output.meta: output.meta["val"] = outputs_meta_val[0] - outputs_meta_val.pop(0) + outputs_meta_val.pop(0) elif isinstance(output, tuple): for node in output: set_output_meta_val(node, outputs_meta_val) diff --git a/tests/py/dynamo/conversion/harness.py b/tests/py/dynamo/conversion/harness.py index 610e424fef..bd4663f5db 100644 --- a/tests/py/dynamo/conversion/harness.py +++ 
b/tests/py/dynamo/conversion/harness.py @@ -262,7 +262,7 @@ def run_test( atol=ATOL, precision=dtype.f32, check_dtype=True, - use_dynamo_tracer=False, + use_dynamo_tracer=True, enable_passes=False, propagate_shapes=False, int32_reqd=False, From cff64a4a0470d5de8756886872b1a4a76f7f8ca4 Mon Sep 17 00:00:00 2001 From: lanluo-nvidia Date: Wed, 23 Oct 2024 12:35:16 -0700 Subject: [PATCH 20/33] test --- tests/py/dynamo/conversion/harness.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/py/dynamo/conversion/harness.py b/tests/py/dynamo/conversion/harness.py index bd4663f5db..610e424fef 100644 --- a/tests/py/dynamo/conversion/harness.py +++ b/tests/py/dynamo/conversion/harness.py @@ -262,7 +262,7 @@ def run_test( atol=ATOL, precision=dtype.f32, check_dtype=True, - use_dynamo_tracer=True, + use_dynamo_tracer=False, enable_passes=False, propagate_shapes=False, int32_reqd=False, From 933abac9690d4530486b5d7f8cfd24d0bb1f5efb Mon Sep 17 00:00:00 2001 From: lanluo-nvidia Date: Wed, 23 Oct 2024 15:35:39 -0700 Subject: [PATCH 21/33] test --- py/torch_tensorrt/dynamo/utils.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/py/torch_tensorrt/dynamo/utils.py b/py/torch_tensorrt/dynamo/utils.py index b8e1489fc4..1b9491675a 100644 --- a/py/torch_tensorrt/dynamo/utils.py +++ b/py/torch_tensorrt/dynamo/utils.py @@ -668,12 +668,12 @@ def get_output_meta_val(output: Any) -> List[Any]: if isinstance(output, torch.fx.node.Node): if "val" in output.meta: output_meta_val.append(output.meta["val"]) - elif isinstance(output, tuple): + elif isinstance(output, (tuple, list)): for node in output: output_meta_val.extend(get_output_meta_val(node)) else: raise ValueError( - f"expect torch.fx.node.Node or a tuple of torch.fx.node.Node type, got unexpected types: {type(output)=}" + f"expect torch.fx.node.Node or a tuple/list of torch.fx.node.Node type, got unexpected types: {type(output)=}" ) return output_meta_val @@ -684,12 +684,12 @@ def set_output_meta_val(output: Any, outputs_meta_val: List[Any]) -> None: if "val" not in output.meta: output.meta["val"] = outputs_meta_val[0] outputs_meta_val.pop(0) - elif isinstance(output, tuple): + elif isinstance(output, (tuple, list)): for node in output: set_output_meta_val(node, outputs_meta_val) else: raise ValueError( - f"expect torch.fx.node.Node or a tuple of torch.fx.node.Node type, got unexpected types: {type(output)=}" + f"expect torch.fx.node.Node or a tuple/list of torch.fx.node.Node type, got unexpected types: {type(output)=}" ) @@ -707,11 +707,11 @@ def get_output_dtypes(output: Any, truncate_doulbe: bool = False) -> List[dtype] raise ValueError( f"node.name={output.name}: node.meta['val'] does not exist, expect node.meta['val'] exists for each output node" ) - elif isinstance(output, tuple): + elif isinstance(output, (tuple, list)): for ele in output: output_dtypes.extend(get_output_dtypes(ele)) else: raise ValueError( - f"got unexpected type {type(output)}, expected type is a torch.fx.node.Node or a tuple of torch.fx.node.Node" + f"got unexpected type {type(output)}, expected type is a torch.fx.node.Node or a tuple/list of torch.fx.node.Node" ) return output_dtypes From 841768410a767498d187a10a439e034212b73998 Mon Sep 17 00:00:00 2001 From: lanluo-nvidia Date: Fri, 25 Oct 2024 07:50:03 -0700 Subject: [PATCH 22/33] resolve comments --- py/torch_tensorrt/dynamo/_compiler.py | 10 +--- py/torch_tensorrt/dynamo/_exporter.py | 2 +- .../dynamo/conversion/_conversion.py | 60 +------------------ 
.../dynamo/lowering/passes/lower_linear.py | 44 +++++--------- .../dynamo/lowering/passes/pass_utils.py | 23 +------ .../dynamo/lowering/passes/view_to_reshape.py | 3 +- py/torch_tensorrt/dynamo/utils.py | 58 ++++++++++++------ tests/py/dynamo/conversion/harness.py | 59 ++++++++++++++++-- 8 files changed, 115 insertions(+), 144 deletions(-) diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py index a792ce4329..2fefe86a08 100644 --- a/py/torch_tensorrt/dynamo/_compiler.py +++ b/py/torch_tensorrt/dynamo/_compiler.py @@ -455,14 +455,10 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool: raise ValueError( f"node_name: {name} does not exist in the submodule node dictionary" ) - # set the submodule meta val back to the parent trt_module_node - if "val" not in submodule_node_dict[name].meta: - outputs = [node for node in submodule.graph.nodes if node.op == "output"] - outputs = outputs[0].args - outputs_meta_val = get_output_meta_val(outputs) - assert len(outputs_meta_val) > 0 - submodule_node_dict[name].meta["val"] = outputs_meta_val + metadata = get_output_meta_val(submodule) + assert len(metadata) > 0 + submodule_node_dict[name].meta = metadata subgraph_data = PerSubgraphData() subgraph_data.subgraph_name = name diff --git a/py/torch_tensorrt/dynamo/_exporter.py b/py/torch_tensorrt/dynamo/_exporter.py index 8bb67067b3..ae7c09caf8 100644 --- a/py/torch_tensorrt/dynamo/_exporter.py +++ b/py/torch_tensorrt/dynamo/_exporter.py @@ -368,7 +368,7 @@ def inline_trt_modules(gm: torch.fx.GraphModule) -> torch.fx.GraphModule: if "val" not in trt_module_node.meta: raise ValueError( - f"trt_module_node: {trt_module_node.name} does not have the meta['val'] info, it should be set during dynamo compile_module step." + f"trt_module_node: {trt_module_node.name} does not have the metadata which should be set during dynamo compile_module step." 
) num_outputs = len(trt_module_node.meta["val"]) # Insert a call_function node to perform inference on TRT engine diff --git a/py/torch_tensorrt/dynamo/conversion/_conversion.py b/py/torch_tensorrt/dynamo/conversion/_conversion.py index ed2e990823..6dad862892 100644 --- a/py/torch_tensorrt/dynamo/conversion/_conversion.py +++ b/py/torch_tensorrt/dynamo/conversion/_conversion.py @@ -5,8 +5,6 @@ import tensorrt as trt import torch -from torch.fx.experimental.proxy_tensor import unset_fake_temporarily -from torch_tensorrt._Device import Device from torch_tensorrt._enums import dtype from torch_tensorrt._features import ENABLED_FEATURES from torch_tensorrt._Input import Input @@ -17,11 +15,7 @@ TRTInterpreterResult, ) from torch_tensorrt.dynamo.runtime import PythonTorchTensorRTModule, TorchTensorRTModule -from torch_tensorrt.dynamo.utils import ( - get_model_device, - get_output_dtypes, - get_torch_inputs, -) +from torch_tensorrt.dynamo.utils import get_output_dtypes logger = logging.getLogger(__name__) @@ -39,58 +33,6 @@ def infer_module_output_dtypes( return get_output_dtypes(outputs, truncate_double) -# this method is only used in our converter test to infer the module output dtypes via dummy inference -# which is due to fx.symbolic_trace does not have the meta['val'] info in the node -# TODO: lan to remove this once our converter test is moved from fx.symbolic_trace to dynamo trace -def infer_module_output_dtypes_for_test( - module: torch.fx.GraphModule, - inputs: Sequence[Input], - device: Device, - kwarg_inputs: Optional[dict[str, Any]] = None, - truncate_double: bool = False, -) -> List[dtype]: - """ - This function performs model inference to determine the output dtypes - and truncates them accordingly. inputs can be either arg_inputs or flattened input list. - If it is flattened list, kwarg_inputs should be None, as it is already included in the flattened input. - """ - # TODO: We can also determine output dtypes from the module.graph based on node metadata. - # However, our converter tests use fx.symbolic_trace which sometimes does not provide metadata, - # so we stick to the model inference approach currently. - with unset_fake_temporarily(): - # Get the device on which the model exists - # For large models, this can be done on CPU to save GPU memory allocation for TRT. 
- device = get_model_device(module) - torch_inputs = get_torch_inputs(inputs, device) - if kwarg_inputs is None: - kwarg_inputs = {} - torch_kwarg_inputs = get_torch_inputs(kwarg_inputs, device) - module_outputs = module(*torch_inputs, **torch_kwarg_inputs) - if not isinstance(module_outputs, (list, tuple)): - module_outputs = [module_outputs] - - # Int64 outputs can sometimes be generated from within other operators - # such as aten.sum - such outputs can be truncated - output_dtypes = [] - for output in module_outputs: - output_ = output - # We don't need to check if output is nested here because the input module will be flattened - if not isinstance(output, torch.Tensor): - if isinstance(output, str): - raise ValueError( - f"Received an output type {type(output)} that's not in the acceptable datatypes (https://pytorch.org/docs/stable/tensor_attributes.html#torch.dtype)" - ) - else: - output_ = torch.tensor(output) - - if truncate_double and output_.dtype == dtype.float64: - output_dtypes.append(dtype.float32) - else: - output_dtypes.append(dtype._from(output_.dtype)) - - return output_dtypes - - def interpret_module_to_result( module: torch.fx.GraphModule, inputs: Sequence[Input], diff --git a/py/torch_tensorrt/dynamo/lowering/passes/lower_linear.py b/py/torch_tensorrt/dynamo/lowering/passes/lower_linear.py index 54e63a496f..dca3d9ed47 100644 --- a/py/torch_tensorrt/dynamo/lowering/passes/lower_linear.py +++ b/py/torch_tensorrt/dynamo/lowering/passes/lower_linear.py @@ -1,12 +1,11 @@ import logging -from typing import Callable, Tuple import torch from torch_tensorrt.dynamo._settings import CompilationSettings from torch_tensorrt.dynamo.lowering.passes.pass_utils import ( clean_up_graph_after_modifications, ) -from torch_tensorrt.dynamo.utils import get_output_meta_val, set_output_meta_val +from torch_tensorrt.dynamo.utils import get_metadata, set_metadata logger = logging.getLogger(__name__) @@ -15,44 +14,29 @@ def lower_linear( gm: torch.fx.GraphModule, settings: CompilationSettings ) -> torch.fx.GraphModule: """Replace aten.linear with an equivalent implementation which can be easily converted to TRT""" - - outputs = [node for node in gm.graph.nodes if node.op == "output"] - outputs = outputs[0].args - outputs_meta_val = get_output_meta_val(outputs) - - orig, replacement = linear_replacement() - replaced_nodes = torch.fx.subgraph_rewriter.replace_pattern(gm, orig, replacement) - - if len(replaced_nodes) > 0: - gm = clean_up_graph_after_modifications(gm) - logger.debug(f"Graph after lowering linear:\n{gm.graph}") - - outputs = [node for node in gm.graph.nodes if node.op == "output"] - outputs = outputs[0].args - output_num = len(outputs_meta_val) - assert output_num > 0 - set_output_meta_val(outputs, outputs_meta_val) - return gm - - -def linear_replacement() -> Tuple[ - torch.fx.GraphModule, - Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor], -]: - """Constructs the original and replacement functions for linear""" + orig_op = torch.ops.aten.addmm.default + replacement_op = torch.ops.aten.linear.default # Original graph def orig( input: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor ) -> torch.Tensor: W_T = torch.ops.aten.permute.default(weight, [1, 0]) - out = torch.ops.aten.addmm.default(bias, input, W_T) + out = orig_op(bias, input, W_T) return out # Replacement graph def replacement( input: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor ) -> torch.Tensor: - return torch.ops.aten.linear.default(input, weight, bias) + return replacement_op(input, 
weight, bias) - return orig, replacement + metadata = get_metadata(gm, orig_op) + replaced_nodes = torch.fx.subgraph_rewriter.replace_pattern(gm, orig, replacement) + + if len(replaced_nodes) > 0: + gm = clean_up_graph_after_modifications(gm) + set_metadata(gm, replacement_op, metadata) + logger.debug(f"Graph after lowering linear:\n{gm.graph}") + + return gm diff --git a/py/torch_tensorrt/dynamo/lowering/passes/pass_utils.py b/py/torch_tensorrt/dynamo/lowering/passes/pass_utils.py index 0ffc6d3c76..31a55099c2 100644 --- a/py/torch_tensorrt/dynamo/lowering/passes/pass_utils.py +++ b/py/torch_tensorrt/dynamo/lowering/passes/pass_utils.py @@ -1,4 +1,4 @@ -from typing import Any, List +from typing import List import torch @@ -29,24 +29,3 @@ def get_tensor_placeholders( ] return placeholders - - -def get_metadata( - gm: torch.fx.GraphModule, target_op: Any -) -> List[torch._ops.OpOverload]: - """ - Return the list which has the metadata of all the target_op nodes present in the graph. - """ - return [node.meta for node in gm.graph.nodes if node.target == target_op] - - -def set_metadata( - gm: torch.fx.GraphModule, target_op: Any, metadata: List[torch._ops.OpOverload] -) -> None: - """ - Return the list which has the metadata of all the target_op nodes present in the graph. - """ - target_nodes = [node for node in gm.graph.nodes if node.target == target_op] - assert len(target_nodes) == len(metadata) - for idx, node in enumerate(target_nodes): - node.meta = metadata[idx] diff --git a/py/torch_tensorrt/dynamo/lowering/passes/view_to_reshape.py b/py/torch_tensorrt/dynamo/lowering/passes/view_to_reshape.py index 06632db623..4464555261 100644 --- a/py/torch_tensorrt/dynamo/lowering/passes/view_to_reshape.py +++ b/py/torch_tensorrt/dynamo/lowering/passes/view_to_reshape.py @@ -5,9 +5,8 @@ from torch_tensorrt.dynamo._settings import CompilationSettings from torch_tensorrt.dynamo.lowering.passes.pass_utils import ( clean_up_graph_after_modifications, - get_metadata, - set_metadata, ) +from torch_tensorrt.dynamo.utils import get_metadata, set_metadata logger = logging.getLogger(__name__) diff --git a/py/torch_tensorrt/dynamo/utils.py b/py/torch_tensorrt/dynamo/utils.py index 1b9491675a..ac6c2a936d 100644 --- a/py/torch_tensorrt/dynamo/utils.py +++ b/py/torch_tensorrt/dynamo/utils.py @@ -663,34 +663,51 @@ def get_flat_args_with_check( return flat_args, received_spec -def get_output_meta_val(output: Any) -> List[Any]: - output_meta_val = [] +def get_metadata( + gm: torch.fx.GraphModule, target_op: torch._ops.OpOverload +) -> List[Any]: + """ + Return the list which has the metadata of all the target_op nodes present in the graph. + """ + return [node.meta for node in gm.graph.nodes if node.target == target_op] + + +def set_metadata( + gm: torch.fx.GraphModule, target_op: torch._ops.OpOverload, metadata: List[Any] +) -> None: + """ + Return the list which has the metadata of all the target_op nodes present in the graph. 
+ """ + target_nodes = [node for node in gm.graph.nodes if node.target == target_op] + assert len(target_nodes) == len(metadata) + for idx, node in enumerate(target_nodes): + node.meta = metadata[idx] + + +def get_output_target_ops(output: Any) -> List[torch._ops.OpOverload]: + ret = [] if isinstance(output, torch.fx.node.Node): if "val" in output.meta: - output_meta_val.append(output.meta["val"]) + ret.append(output.target) elif isinstance(output, (tuple, list)): for node in output: - output_meta_val.extend(get_output_meta_val(node)) + ret.extend(get_output_target_ops(node)) else: raise ValueError( f"expect torch.fx.node.Node or a tuple/list of torch.fx.node.Node type, got unexpected types: {type(output)=}" ) - return output_meta_val + return ret -def set_output_meta_val(output: Any, outputs_meta_val: List[Any]) -> None: - if isinstance(output, torch.fx.node.Node): - assert len(outputs_meta_val) > 0 - if "val" not in output.meta: - output.meta["val"] = outputs_meta_val[0] - outputs_meta_val.pop(0) - elif isinstance(output, (tuple, list)): - for node in output: - set_output_meta_val(node, outputs_meta_val) - else: - raise ValueError( - f"expect torch.fx.node.Node or a tuple/list of torch.fx.node.Node type, got unexpected types: {type(output)=}" - ) +def get_output_meta_val( + gm: torch.fx.GraphModule, +) -> List[Any]: + outputs = [node for node in gm.graph.nodes if node.op == "output"] + assert len(outputs) > 0 + outputs = outputs[0].args + target_ops = get_output_target_ops(outputs) + assert len(target_ops) > 0 + return get_metadata(gm, target_ops[0]) def get_output_dtypes(output: Any, truncate_doulbe: bool = False) -> List[dtype]: @@ -703,9 +720,12 @@ def get_output_dtypes(output: Any, truncate_doulbe: bool = False) -> List[dtype] output_dtypes.append(dtype.float32) else: output_dtypes.append(dtype._from(output_meta.dtype)) + elif "tensor_meta" in output.meta: + output_meta = output.meta["tensor_meta"] + output_dtypes.append(dtype._from(output_meta.dtype)) else: raise ValueError( - f"node.name={output.name}: node.meta['val'] does not exist, expect node.meta['val'] exists for each output node" + f"node.name={output.name}: metadata does not exist, expect metadata exists for each output node" ) elif isinstance(output, (tuple, list)): for ele in output: diff --git a/tests/py/dynamo/conversion/harness.py b/tests/py/dynamo/conversion/harness.py index 610e424fef..e648fb1e53 100644 --- a/tests/py/dynamo/conversion/harness.py +++ b/tests/py/dynamo/conversion/harness.py @@ -3,22 +3,21 @@ import logging import time import unittest -from typing import Any, Callable, List, Optional, Tuple +from typing import Any, Callable, List, Optional, Sequence, Tuple import torch import torch_tensorrt +from torch.fx.experimental.proxy_tensor import unset_fake_temporarily from torch.fx.passes.shape_prop import ShapeProp from torch.testing._internal.common_utils import TestCase from torch_tensorrt import Input +from torch_tensorrt._Device import Device from torch_tensorrt._enums import dtype from torch_tensorrt.dynamo import _defaults from torch_tensorrt.dynamo._settings import CompilationSettings # Use interpreter, input spec, and test case from fx_ts_compat to test Dynamo Converter Registry from torch_tensorrt.dynamo.conversion import TRTInterpreter -from torch_tensorrt.dynamo.conversion._conversion import ( - infer_module_output_dtypes_for_test, -) from torch_tensorrt.dynamo.lowering import ( get_decompositions, post_lowering, @@ -30,6 +29,58 @@ _LOGGER: logging.Logger = logging.getLogger(__name__) +# this 
method is only used in our converter test to infer the module output dtypes via dummy inference +# which is due to fx.symbolic_trace does not have the meta['val'] info in the node +# TODO: lan to remove this once our converter test is moved from fx.symbolic_trace to dynamo trace +def infer_module_output_dtypes_for_test( + module: torch.fx.GraphModule, + inputs: Sequence[Input], + device: Device, + kwarg_inputs: Optional[dict[str, Any]] = None, + truncate_double: bool = False, +) -> List[dtype]: + """ + This function performs model inference to determine the output dtypes + and truncates them accordingly. inputs can be either arg_inputs or flattened input list. + If it is flattened list, kwarg_inputs should be None, as it is already included in the flattened input. + """ + # TODO: We can also determine output dtypes from the module.graph based on node metadata. + # However, our converter tests use fx.symbolic_trace which sometimes does not provide metadata, + # so we stick to the model inference approach currently. + with unset_fake_temporarily(): + # Get the device on which the model exists + # For large models, this can be done on CPU to save GPU memory allocation for TRT. + device = get_model_device(module) + torch_inputs = get_torch_inputs(inputs, device) + if kwarg_inputs is None: + kwarg_inputs = {} + torch_kwarg_inputs = get_torch_inputs(kwarg_inputs, device) + module_outputs = module(*torch_inputs, **torch_kwarg_inputs) + if not isinstance(module_outputs, (list, tuple)): + module_outputs = [module_outputs] + + # Int64 outputs can sometimes be generated from within other operators + # such as aten.sum - such outputs can be truncated + output_dtypes = [] + for output in module_outputs: + output_ = output + # We don't need to check if output is nested here because the input module will be flattened + if not isinstance(output, torch.Tensor): + if isinstance(output, str): + raise ValueError( + f"Received an output type {type(output)} that's not in the acceptable datatypes (https://pytorch.org/docs/stable/tensor_attributes.html#torch.dtype)" + ) + else: + output_ = torch.tensor(output) + + if truncate_double and output_.dtype == dtype.float64: + output_dtypes.append(dtype.float32) + else: + output_dtypes.append(dtype._from(output_.dtype)) + + return output_dtypes + + def fetch_attr(mod, target): """ Fetch an attribute from the ``Module`` hierarchy of ``mod.module``. 
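
Note on the approach the preceding patches converge on: output dtypes are no longer obtained by running the module with real tensors; they are read from the meta["val"] FakeTensors that torch.export/dynamo tracing records on the FX output nodes (see get_output_dtypes above). Below is a minimal sketch of that idea, illustrative only and not the torch_tensorrt implementation; it assumes the traced nodes carry meta["val"], which torch.export populates, and TinyModel is a made-up example module.

import torch

class TinyModel(torch.nn.Module):
    def forward(self, x):
        return x.sum(dim=-1), x > 0

ep = torch.export.export(TinyModel(), (torch.randn(2, 3),))
output_node = next(n for n in ep.graph.nodes if n.op == "output")
for out in output_node.args[0]:
    fake = out.meta["val"]  # FakeTensor recorded during tracing
    # Expect torch.float32 and torch.bool here; a float64 output is what
    # truncate_double would downcast to float32.
    print(out.name, fake.dtype, tuple(fake.shape))

Reading dtypes (and shapes) from node metadata avoids a dummy forward pass, so no real inference or GPU allocation is needed just to discover output signatures; the fx.symbolic_trace-based converter tests keep the inference fallback only because that tracer does not populate meta["val"].
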
From 8676f886f8962059316f1a45e6bd949d345880e6 Mon Sep 17 00:00:00 2001 From: lanluo-nvidia Date: Fri, 25 Oct 2024 10:48:21 -0700 Subject: [PATCH 23/33] test --- py/torch_tensorrt/dynamo/_compiler.py | 18 +++++++++++++----- py/torch_tensorrt/dynamo/utils.py | 23 +++++++++++------------ 2 files changed, 24 insertions(+), 17 deletions(-) diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py index 2fefe86a08..c183c3abfb 100644 --- a/py/torch_tensorrt/dynamo/_compiler.py +++ b/py/torch_tensorrt/dynamo/_compiler.py @@ -36,7 +36,7 @@ ) from torch_tensorrt.dynamo.utils import ( get_flat_args_with_check, - get_output_meta_val, + get_output_metadata, parse_graph_io, prepare_inputs, set_log_level, @@ -455,10 +455,18 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool: raise ValueError( f"node_name: {name} does not exist in the submodule node dictionary" ) - # set the submodule meta val back to the parent trt_module_node - metadata = get_output_meta_val(submodule) - assert len(metadata) > 0 - submodule_node_dict[name].meta = metadata + + # set the submodule metadata back to the parent trt_module_node + metadata_list = get_output_metadata(submodule) + assert len(metadata_list) > 0 + if "val" not in submodule_node_dict[name].meta: + meta_val_list = [ + metadata["val"] for metadata in metadata_list if "val" in metadata + ] + submodule_node_dict[name].meta["val"] = meta_val_list + logger.debug( + f"Update submodule output metadata back to the parent trt_module_node: {name}" + ) subgraph_data = PerSubgraphData() subgraph_data.subgraph_name = name diff --git a/py/torch_tensorrt/dynamo/utils.py b/py/torch_tensorrt/dynamo/utils.py index ac6c2a936d..3a1c339074 100644 --- a/py/torch_tensorrt/dynamo/utils.py +++ b/py/torch_tensorrt/dynamo/utils.py @@ -684,30 +684,29 @@ def set_metadata( node.meta = metadata[idx] -def get_output_target_ops(output: Any) -> List[torch._ops.OpOverload]: +def flatten_nodes(nodes: Any) -> List[torch.fx.node.Node]: ret = [] - if isinstance(output, torch.fx.node.Node): - if "val" in output.meta: - ret.append(output.target) - elif isinstance(output, (tuple, list)): - for node in output: - ret.extend(get_output_target_ops(node)) + if isinstance(nodes, torch.fx.node.Node): + ret.append(nodes) + elif isinstance(nodes, (tuple, list)): + for node in nodes: + ret.extend(flatten_nodes(node)) else: raise ValueError( - f"expect torch.fx.node.Node or a tuple/list of torch.fx.node.Node type, got unexpected types: {type(output)=}" + f"expect torch.fx.node.Node or a tuple/list of torch.fx.node.Node type, got unexpected types: {type(nodes)=}" ) return ret -def get_output_meta_val( +def get_output_metadata( gm: torch.fx.GraphModule, ) -> List[Any]: outputs = [node for node in gm.graph.nodes if node.op == "output"] assert len(outputs) > 0 outputs = outputs[0].args - target_ops = get_output_target_ops(outputs) - assert len(target_ops) > 0 - return get_metadata(gm, target_ops[0]) + nodes = flatten_nodes(outputs) + assert len(nodes) > 0 + return [node.meta for node in nodes] def get_output_dtypes(output: Any, truncate_doulbe: bool = False) -> List[dtype]: From d8e52bf97e9814fe711023114566dfcfaa73281f Mon Sep 17 00:00:00 2001 From: lanluo-nvidia Date: Sat, 26 Oct 2024 19:04:53 -0700 Subject: [PATCH 24/33] test --- py/torch_tensorrt/dynamo/utils.py | 8 +- tests/py/dynamo/conversion/harness.py | 25 +++-- tests/py/dynamo/conversion/test_atan2_aten.py | 16 +++- tests/py/dynamo/conversion/test_atanh_aten.py | 4 +- tests/py/dynamo/conversion/test_attention.py | 93 
++++++++++++++----- 5 files changed, 111 insertions(+), 35 deletions(-) diff --git a/py/torch_tensorrt/dynamo/utils.py b/py/torch_tensorrt/dynamo/utils.py index b8e1489fc4..a20d402487 100644 --- a/py/torch_tensorrt/dynamo/utils.py +++ b/py/torch_tensorrt/dynamo/utils.py @@ -695,7 +695,13 @@ def set_output_meta_val(output: Any, outputs_meta_val: List[Any]) -> None: def get_output_dtypes(output: Any, truncate_doulbe: bool = False) -> List[dtype]: output_dtypes = [] - if isinstance(output, torch.fx.node.Node): + if isinstance(output, int): + output_dtypes.append(dtype.int64) + elif isinstance(output, bool): + output_dtypes.append(dtype.bool) + elif isinstance(output, float): + output_dtypes.append(dtype.float32) + elif isinstance(output, torch.fx.node.Node): if "val" in output.meta: output_meta = output.meta["val"] if isinstance(output_meta, (FakeTensor, torch.Tensor)): diff --git a/tests/py/dynamo/conversion/harness.py b/tests/py/dynamo/conversion/harness.py index f6251f3913..c7906e49d0 100644 --- a/tests/py/dynamo/conversion/harness.py +++ b/tests/py/dynamo/conversion/harness.py @@ -33,7 +33,7 @@ _LOGGER: logging.Logger = logging.getLogger(__name__) -# this is to enable dynamo tracer as Truein the converter test files batch by batch +# this is to enable dynamo tracer as True in the converter test files batch by batch def get_use_dynamo_tracer(use_dynamo_tracer: Any) -> bool: # if in our converter tests we specifically set use_dynamo_tracer field, honor it if use_dynamo_tracer is not None and isinstance(use_dynamo_tracer, bool): @@ -296,12 +296,15 @@ def run_test( atol=ATOL, precision=dtype.f32, check_dtype=True, + use_dynamo_tracer=None, enable_passes=False, propagate_shapes=False, int32_reqd=False, make_refittable=False, ): - + # TODO: lan to remove this and set use_dynamo_traccer to True by default + # once all the converter test files are moved to use_dynamo_tracer + use_dynamo_tracer = get_use_dynamo_tracer(use_dynamo_tracer) # Previous instance of the interpreter auto-casted 64-bit inputs # We replicate this behavior here compilation_settings = CompilationSettings( @@ -314,7 +317,7 @@ def run_test( mod = self.generate_graph( mod, inputs, - use_dynamo_tracer=True, + use_dynamo_tracer=use_dynamo_tracer, enable_passes=enable_passes, propagate_shapes=propagate_shapes, settings=compilation_settings, @@ -348,10 +351,18 @@ def run_test( output_dtypes = None if check_dtype: - output_dtypes = infer_module_output_dtypes( - mod, - truncate_double=compilation_settings.truncate_double, - ) + if use_dynamo_tracer: + output_dtypes = infer_module_output_dtypes( + mod, + truncate_double=compilation_settings.truncate_double, + ) + else: + output_dtypes = infer_module_output_dtypes_for_test( + mod, + input_specs, + compilation_settings.device, + truncate_double=compilation_settings.truncate_double, + ) _LOGGER.debug(f"Compilation settings: {compilation_settings}") _LOGGER.debug(f"Inputs: {input_specs}") diff --git a/tests/py/dynamo/conversion/test_atan2_aten.py b/tests/py/dynamo/conversion/test_atan2_aten.py index a35eb5409a..2db650cbaa 100644 --- a/tests/py/dynamo/conversion/test_atan2_aten.py +++ b/tests/py/dynamo/conversion/test_atan2_aten.py @@ -1,3 +1,5 @@ +import unittest + import torch import torch.nn as nn from parameterized import parameterized @@ -182,10 +184,17 @@ def forward(self, lhs_val, rhs_val): ) +# torch.ops.aten.atan2.out will be decomposed/partitioned into core aten ops which torch_tensorrt supported in run_on_acc and +# non supported ops in run_on_gpu in dynamo tracer, it works via 
torch_tensorrt.dynamo.compile workflow +# but it won't be valid for our converter test framework, so skip it here. +@unittest.skip("skip torch.ops.aten.atan2.out converter test") class TestAtan2OutConverter(DispatchTestCase): @parameterized.expand( [ - ((10,), (5,), torch.float), + # dynamo trace does not allow output to be in a different shape + # raise Unsupported(msg, case_name=case_name) + # torch._dynamo.exc.Unsupported: out variants with resizing on graph inputs + ((5,), (5,), torch.float), ((10,), (10,), torch.float), ] ) @@ -255,7 +264,10 @@ def forward(self, lhs_val, rhs_val, out): ), ] self.run_test_with_dynamic_shape( - atan2(), input_specs, output_dtypes=[output_type] + atan2(), + input_specs, + output_dtypes=[output_type], + use_dynamo_tracer=False, ) diff --git a/tests/py/dynamo/conversion/test_atanh_aten.py b/tests/py/dynamo/conversion/test_atanh_aten.py index 3b7ac77541..2438bf8252 100644 --- a/tests/py/dynamo/conversion/test_atanh_aten.py +++ b/tests/py/dynamo/conversion/test_atanh_aten.py @@ -58,7 +58,7 @@ def forward(self, input): ( "3d_dim_dtype_float", (1, 1, 1), - (1, 2, 3), + (2, 2, 3), (3, 3, 3), torch.float, torch.float, @@ -66,7 +66,7 @@ def forward(self, input): ( "3d_dim_dtype_int32", (1, 1, 1), - (1, 2, 4), + (2, 2, 4), (2, 3, 5), torch.int32, torch.float, diff --git a/tests/py/dynamo/conversion/test_attention.py b/tests/py/dynamo/conversion/test_attention.py index 5109a2e2ca..bc41206df8 100644 --- a/tests/py/dynamo/conversion/test_attention.py +++ b/tests/py/dynamo/conversion/test_attention.py @@ -3,6 +3,7 @@ import torch import torch.nn as nn from parameterized import parameterized +from torch.export import Dim from torch.testing._internal.common_utils import run_tests from torch_tensorrt import Input @@ -24,33 +25,41 @@ def forward(self, query, key, value): key = torch.rand(key_shape, dtype=torch.float16) value = torch.rand(key_shape, dtype=torch.float16) inputs.extend([query, key, value]) - self.run_test(SDPA(), inputs, rtol=1e-2, atol=1e-2, precision=torch.float16) + self.run_test( + SDPA(), + inputs, + rtol=1e-2, + atol=1e-2, + precision=torch.float16, + enable_passes=True, + ) + @unittest.skip("need to change to custom dynamic shapes") @parameterized.expand( [ - ( - "4d-2d", - (4, 2, 16, 32), - (6, 3, 32, 64), - (32, 8, 64, 128), - (4, 32), - (4, 64), - (16, 128), - ), - ( - "4d-3d", - (2, 2, 2, 2), - (3, 3, 3, 4), - (3, 4, 4, 5), - (2, 3, 2), - (3, 3, 4), - (4, 5, 5), - ), + # ( + # "4d-2d", + # (4, 2, 16, 32), + # (6, 3, 32, 64), + # (32, 8, 64, 128), + # (4, 32), + # (4, 64), + # (16, 128), + # ), + # ( + # "4d-3d", + # (2, 2, 2, 2), + # (3, 3, 3, 4), + # (3, 4, 4, 5), + # (2, 3, 2), + # (3, 3, 4), + # (4, 5, 5), + # ), ( "4d-4d", - (4, 2, 12, 16), - (6, 3, 16, 32), - (32, 8, 18, 64), + (4, 2, 12, 4), + (6, 3, 16, 8), + (32, 8, 18, 16), (4, 2, 4, 16), (6, 3, 8, 32), (32, 8, 12, 64), @@ -102,9 +111,42 @@ def forward(self, query, key, value): max_shape=key_max_shape, ), ] + dyn_dim_0 = Dim("dyn_dim_0", min=4, max=32) + dyn_dim_1 = Dim("dyn_dim_1", min=2, max=8) - self.run_test_with_dynamic_shape(SDPA(), inputs) + q_dyn_dim_2 = Dim("q_dyn_dim_2", min=12, max=18) + q_dyn_dim_3 = Dim("q_dyn_dim_3", min=4, max=16) + + k_dyn_dim_2 = Dim("k_dyn_dim_2", min=4, max=12) + k_dyn_dim_3 = 4 * q_dyn_dim_3 # Dim("k_dyn_dim_3", min=16, max=64) + torch_export_dynamic_shapes = {} + torch_export_dynamic_shapes["query"] = { + 0: dyn_dim_0, + 1: dyn_dim_1, + 2: q_dyn_dim_2, + 3: q_dyn_dim_3, + } + torch_export_dynamic_shapes["key"] = { + 0: dyn_dim_0, + 1: dyn_dim_1, + 2: 
k_dyn_dim_2, + 3: k_dyn_dim_3, + } + torch_export_dynamic_shapes["value"] = { + 0: dyn_dim_0, + 1: dyn_dim_1, + 2: k_dyn_dim_2, + 3: k_dyn_dim_3, + } + self.run_test_with_dynamic_shape( + SDPA(), + inputs, + torch_export_dynamic_shapes=torch_export_dynamic_shapes, + enable_passes=True, + ) + + @unittest.skip("need to change to custom dynamic shapes") @parameterized.expand( [ ( @@ -175,6 +217,7 @@ def forward(self, query, key, value): self.run_test_with_dynamic_shape(SDPA(), inputs) + @unittest.skip("need to change to custom dynamic shapes") @parameterized.expand( [ ( @@ -252,6 +295,10 @@ def forward(self, query, key, value): self.run_test_with_dynamic_shape(SDPA(), inputs) + # it is already added in the integration test + @unittest.skip( + "skip torch.nn.functional.scaled_dot_product_attention converter test" + ) @parameterized.expand([((32, 8, 128, 64), (32, 8, 128, 64))]) def test_sdpa_causal(self, query_shape, key_shape): class SDPA(nn.Module): From 8b3842a2a1adc45a4af89728873a42347c845a6c Mon Sep 17 00:00:00 2001 From: lanluo-nvidia Date: Sat, 26 Oct 2024 19:13:49 -0700 Subject: [PATCH 25/33] test --- tests/py/dynamo/conversion/harness.py | 24 +++++++++++++++---- tests/py/dynamo/conversion/test_asinh_aten.py | 6 ++--- 2 files changed, 23 insertions(+), 7 deletions(-) diff --git a/tests/py/dynamo/conversion/harness.py b/tests/py/dynamo/conversion/harness.py index 20fab0e73c..0bef97d077 100644 --- a/tests/py/dynamo/conversion/harness.py +++ b/tests/py/dynamo/conversion/harness.py @@ -20,10 +20,7 @@ # Use interpreter, input spec, and test case from fx_ts_compat to test Dynamo Converter Registry from torch_tensorrt.dynamo.conversion import TRTInterpreter -from torch_tensorrt.dynamo.conversion._conversion import ( - infer_module_output_dtypes, - infer_module_output_dtypes_for_test, -) +from torch_tensorrt.dynamo.conversion._conversion import infer_module_output_dtypes from torch_tensorrt.dynamo.lowering import ( get_decompositions, post_lowering, @@ -35,6 +32,25 @@ _LOGGER: logging.Logger = logging.getLogger(__name__) +# this is to enable dynamo tracer as True in the converter test files batch by batch +def get_use_dynamo_tracer(use_dynamo_tracer: Any) -> bool: + # if in our converter tests we specifically set use_dynamo_tracer field, honor it + if use_dynamo_tracer is not None and isinstance(use_dynamo_tracer, bool): + return use_dynamo_tracer + # if in our converter tests, we did not specify use_dynamo_tracer field + import inspect + import os + import re + + filename = os.path.basename(inspect.stack()[2].filename) + # enable converter test files which starts with test_a*.py to use dynamo tracer + pattern = re.compile("^test_([a])+") + if pattern.match(filename): + return True + else: + return False + + # this method is only used in our converter test to infer the module output dtypes via dummy inference # which is due to fx.symbolic_trace does not have the meta['val'] info in the node # TODO: lan to remove this once our converter test is moved from fx.symbolic_trace to dynamo trace diff --git a/tests/py/dynamo/conversion/test_asinh_aten.py b/tests/py/dynamo/conversion/test_asinh_aten.py index 025cdbefb5..ca9b7e9126 100644 --- a/tests/py/dynamo/conversion/test_asinh_aten.py +++ b/tests/py/dynamo/conversion/test_asinh_aten.py @@ -11,9 +11,9 @@ class TestAsinhConverter(DispatchTestCase): @parameterized.expand( [ ((10,), torch.float), - # ((1, 20), torch.float), - # ((2, 3, 4), torch.float), - # ((2, 3, 4, 5), torch.float), + ((1, 20), torch.float), + ((2, 3, 4), torch.float), + ((2, 3, 4, 
5), torch.float), ] ) def test_asinh_float(self, input_shape, dtype): From 7ddf56f27b9886c91b3c1c4de91d4b864283fc47 Mon Sep 17 00:00:00 2001 From: lanluo-nvidia Date: Sun, 27 Oct 2024 10:01:43 -0700 Subject: [PATCH 26/33] test --- tests/py/dynamo/conversion/test_attention.py | 46 ++--- .../conversion/test_composite_aten_op.py | 194 ++++++++++++++++++ 2 files changed, 217 insertions(+), 23 deletions(-) create mode 100644 tests/py/dynamo/conversion/test_composite_aten_op.py diff --git a/tests/py/dynamo/conversion/test_attention.py b/tests/py/dynamo/conversion/test_attention.py index bc41206df8..984a5a3093 100644 --- a/tests/py/dynamo/conversion/test_attention.py +++ b/tests/py/dynamo/conversion/test_attention.py @@ -34,27 +34,26 @@ def forward(self, query, key, value): enable_passes=True, ) - @unittest.skip("need to change to custom dynamic shapes") @parameterized.expand( [ - # ( - # "4d-2d", - # (4, 2, 16, 32), - # (6, 3, 32, 64), - # (32, 8, 64, 128), - # (4, 32), - # (4, 64), - # (16, 128), - # ), - # ( - # "4d-3d", - # (2, 2, 2, 2), - # (3, 3, 3, 4), - # (3, 4, 4, 5), - # (2, 3, 2), - # (3, 3, 4), - # (4, 5, 5), - # ), + ( + "4d-2d", + (4, 2, 16, 32), + (6, 3, 32, 64), + (32, 8, 64, 128), + (4, 32), + (4, 64), + (16, 128), + ), + ( + "4d-3d", + (2, 2, 2, 2), + (3, 3, 3, 4), + (3, 4, 4, 5), + (2, 3, 2), + (3, 3, 4), + (4, 5, 5), + ), ( "4d-4d", (4, 2, 12, 4), @@ -66,6 +65,7 @@ def forward(self, query, key, value): ), ] ) + @unittest.skip("need to change to custom dynamic shapes") def test_sdpa_no_causal_dynamic_shape_with_scale( self, _, @@ -146,7 +146,6 @@ def forward(self, query, key, value): enable_passes=True, ) - @unittest.skip("need to change to custom dynamic shapes") @parameterized.expand( [ ( @@ -169,6 +168,7 @@ def forward(self, query, key, value): ), ] ) + @unittest.skip("need to change to custom dynamic shapes") def test_sdpa_no_causal_no_scale_dynamic_shape( self, _, @@ -217,7 +217,6 @@ def forward(self, query, key, value): self.run_test_with_dynamic_shape(SDPA(), inputs) - @unittest.skip("need to change to custom dynamic shapes") @parameterized.expand( [ ( @@ -252,6 +251,7 @@ def forward(self, query, key, value): ), ] ) + @unittest.skip("need to change to custom dynamic shapes") def test_sdpa_causal_dynamic_shape( self, _, @@ -295,11 +295,11 @@ def forward(self, query, key, value): self.run_test_with_dynamic_shape(SDPA(), inputs) - # it is already added in the integration test + @parameterized.expand([((32, 8, 128, 64), (32, 8, 128, 64))]) + # it is already added in the test_composite_aten_op.py as integration test @unittest.skip( "skip torch.nn.functional.scaled_dot_product_attention converter test" ) - @parameterized.expand([((32, 8, 128, 64), (32, 8, 128, 64))]) def test_sdpa_causal(self, query_shape, key_shape): class SDPA(nn.Module): def forward(self, query, key, value): diff --git a/tests/py/dynamo/conversion/test_composite_aten_op.py b/tests/py/dynamo/conversion/test_composite_aten_op.py new file mode 100644 index 0000000000..769e26983c --- /dev/null +++ b/tests/py/dynamo/conversion/test_composite_aten_op.py @@ -0,0 +1,194 @@ +import unittest + +import pytest +import torch +import torch.nn as nn +import torch_tensorrt as torchtrt +from parameterized import parameterized +from torch.export import Dim +from torch_tensorrt import Input +from torch_tensorrt.dynamo.utils import COSINE_THRESHOLD, cosine_similarity + +assertions = unittest.TestCase() + + +@pytest.mark.unit +@parameterized.expand( + [ + ((5,), (5,)), + ( + ( + 2, + 3, + ), + ( + 2, + 3, + ), + ), + ] +) +def 
test_atan2_out_static_shape(input_shape, out_shape): + class atan2(torch.nn.Module): + def forward(self, lhs_val, rhs_val, out): + return torch.ops.aten.atan2.out(lhs_val, rhs_val, out=out) + + model = atan2().eval().cuda() + inputs = ( + torch.randn(input_shape).cuda(), + torch.randn(input_shape).cuda(), + torch.randn(out_shape).cuda(), + ) + compile_spec = { + "inputs": inputs, + "ir": "dynamo", + "min_block_size": 1, + } + + trt_model = torchtrt.compile(model, **compile_spec) + py_outputs = model(*inputs) + trt_outputs = trt_model(*inputs) + cos_sim = cosine_similarity(py_outputs, trt_outputs) + assertions.assertTrue( + cos_sim > COSINE_THRESHOLD, + msg=f"test_atan2_out_static_shape model TRT outputs don't match with the pytorch model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", + ) + + +@pytest.mark.unit +@parameterized.expand( + [ + ( + ( + 1, + 2, + ), + (2, 3), + (2, 4), + ), + ] +) +def test_atan2_out_dynamic_shape(min_shape, opt_shape, max_shape): + class atan2(torch.nn.Module): + def forward(self, lhs_val, rhs_val, out): + return torch.ops.aten.atan2.out(lhs_val, rhs_val, out=out) + + model = atan2().eval().cuda() + input_spec = [ + Input( + min_shape=min_shape, + opt_shape=opt_shape, + max_shape=max_shape, + ), + Input( + min_shape=min_shape, + opt_shape=opt_shape, + max_shape=max_shape, + ), + Input( + min_shape=min_shape, + opt_shape=opt_shape, + max_shape=max_shape, + ), + ] + + compile_spec = { + "inputs": input_spec, + "ir": "dynamo", + "min_block_size": 1, + } + + trt_model = torchtrt.compile(model, **compile_spec) + inputs = ( + torch.randn(max_shape).cuda(), + torch.randn(max_shape).cuda(), + torch.randn(max_shape).cuda(), + ) + py_outputs = model(*inputs) + trt_outputs = trt_model(*inputs) + cos_sim = cosine_similarity(py_outputs, trt_outputs) + assertions.assertTrue( + cos_sim > COSINE_THRESHOLD, + msg=f"test_atan2_out_dynamic_shape model TRT outputs don't match with the pytorch model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", + ) + + +@parameterized.expand( + [ + ((32, 8, 128, 64), (32, 8, 128, 64), True, None), + ((32, 32, 128, 64), (32, 8, 128, 64), True, 0.1), + ] +) +def test_sdpa_static_shape(query_shape, key_shape, is_causal, scale): + class SDPA(nn.Module): + def forward(self, query, key, value): + return torch.nn.functional.scaled_dot_product_attention( + query, key, value, None, 0.0, is_causal=is_causal, scale=scale + ) + + model = SDPA().eval().cuda() + + query = torch.randn(query_shape, dtype=torch.float16).cuda() + key = torch.randn(key_shape, dtype=torch.float16).cuda() + value = torch.randn(key_shape, dtype=torch.float16).cuda() + inputs = (query, key, value) + compile_spec = { + "inputs": inputs, + "ir": "dynamo", + "min_block_size": 1, + } + + trt_model = torchtrt.compile(model, **compile_spec) + py_outputs = model(*inputs) + trt_outputs = trt_model(*inputs) + cos_sim = cosine_similarity(py_outputs, trt_outputs) + assertions.assertTrue( + cos_sim > COSINE_THRESHOLD, + msg=f"test_sdpa_static_shape model TRT outputs don't match with the pytorch model. 
Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", + ) + + +@parameterized.expand( + [ + (True, None), + (True, 0.1), + (False, None), + ] +) +def test_sdpa_dynamic_shape(is_causal, scale): + class SDPA(nn.Module): + def forward(self, query, key, value): + return torch.nn.functional.scaled_dot_product_attention( + query, key, value, None, 0.0, is_causal=is_causal, scale=scale + ) + + model = SDPA().eval().cuda() + + # N: batch_size + dyn_N = Dim("dyn_N", min=2, max=4) + + # query tensor shape (N, ..., Hq, L, E) + query = torch.randn((3, 3, 4, 64), dtype=torch.float16).cuda() + # key tensor shape (N,...,H, S, E) + key = torch.randn((3, 3, 4, 64), dtype=torch.float16).cuda() + # value tensor shape (N, ..., H, S, Ev) + value = torch.randn((3, 3, 4, 64), dtype=torch.float16).cuda() + + dynamic_shapes = {"query": {0: dyn_N}, "key": {0: dyn_N}, "value": {0: dyn_N}} + inputs = (query, key, value) + + exp_program = torch.export.export(model, inputs, dynamic_shapes=dynamic_shapes) + + compile_spec = { + "inputs": inputs, + "ir": "dynamo", + "min_block_size": 1, + } + trt_model = torchtrt.dynamo.compile(exp_program, **compile_spec) + py_outputs = model(*inputs) + trt_outputs = trt_model(*inputs) + cos_sim = cosine_similarity(py_outputs, trt_outputs) + assertions.assertTrue( + cos_sim > COSINE_THRESHOLD, + msg=f"test_sdpa_dynamic_shape model TRT outputs don't match with the pytorch model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", + ) From 39e0a49a23c8259df45ca43c6dfa3587397aea80 Mon Sep 17 00:00:00 2001 From: lanluo-nvidia Date: Sun, 27 Oct 2024 14:34:27 -0700 Subject: [PATCH 27/33] test --- tests/py/dynamo/conversion/test_composite_aten_op.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/py/dynamo/conversion/test_composite_aten_op.py b/tests/py/dynamo/conversion/test_composite_aten_op.py index 769e26983c..b2f13fb9c6 100644 --- a/tests/py/dynamo/conversion/test_composite_aten_op.py +++ b/tests/py/dynamo/conversion/test_composite_aten_op.py @@ -116,7 +116,6 @@ def forward(self, lhs_val, rhs_val, out): @parameterized.expand( [ ((32, 8, 128, 64), (32, 8, 128, 64), True, None), - ((32, 32, 128, 64), (32, 8, 128, 64), True, 0.1), ] ) def test_sdpa_static_shape(query_shape, key_shape, is_causal, scale): From 076f47ac62e5ace1797f9be6fbd054871284cb76 Mon Sep 17 00:00:00 2001 From: lanluo-nvidia Date: Tue, 29 Oct 2024 09:02:36 -0700 Subject: [PATCH 28/33] resolve comments --- py/torch_tensorrt/dynamo/_compiler.py | 12 +++++++++++- py/torch_tensorrt/dynamo/_tracer.py | 3 --- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py index c183c3abfb..7ae8fa07a3 100644 --- a/py/torch_tensorrt/dynamo/_compiler.py +++ b/py/torch_tensorrt/dynamo/_compiler.py @@ -465,7 +465,17 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool: ] submodule_node_dict[name].meta["val"] = meta_val_list logger.debug( - f"Update submodule output metadata back to the parent trt_module_node: {name}" + f"Updated val metadata for node: {name} with its corresponding submodule outputs" + ) + if "tensor_meta" not in submodule_node_dict[name].meta: + tensor_meta_list = [ + metadata["tensor_meta"] + for metadata in metadata_list + if "tensor_meta" in metadata + ] + submodule_node_dict[name].meta["tensor_meta"] = tensor_meta_list + logger.debug( + f"Updated tensor_meta metadata for node: {name} with its corresponding submodule outputs" ) subgraph_data = PerSubgraphData() diff --git a/py/torch_tensorrt/dynamo/_tracer.py 
b/py/torch_tensorrt/dynamo/_tracer.py index 2c8745ee33..78f7989777 100644 --- a/py/torch_tensorrt/dynamo/_tracer.py +++ b/py/torch_tensorrt/dynamo/_tracer.py @@ -115,9 +115,6 @@ def get_dynamic_shapes_args(mod: torch.nn.Module, inputs: Any) -> dict[str, Any] args = list(signature(mod.forward).parameters.keys()) dynamic_shapes = {} for input, input_name in zip(inputs, args[: len(inputs)]): - # if input.name is not None, also not empty str, use the input.name - if input.name is not None and len(input.name) > 0 and input.name != input_name: - input_name = input.name dynamic_shapes[input_name] = get_dynamic_shapes(input) return dynamic_shapes From 96e93e4d71034c1f705452d35b9dc2441944cf2f Mon Sep 17 00:00:00 2001 From: lanluo-nvidia Date: Tue, 29 Oct 2024 10:45:02 -0700 Subject: [PATCH 29/33] resolve comments --- py/torch_tensorrt/dynamo/_compiler.py | 29 ++++++++++----------------- 1 file changed, 11 insertions(+), 18 deletions(-) diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py index 3ac6bf1555..6b7335b0a3 100644 --- a/py/torch_tensorrt/dynamo/_compiler.py +++ b/py/torch_tensorrt/dynamo/_compiler.py @@ -466,24 +466,17 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool: # set the submodule metadata back to the parent trt_module_node metadata_list = get_output_metadata(submodule) assert len(metadata_list) > 0 - if "val" not in submodule_node_dict[name].meta: - meta_val_list = [ - metadata["val"] for metadata in metadata_list if "val" in metadata - ] - submodule_node_dict[name].meta["val"] = meta_val_list - logger.debug( - f"Updated val metadata for node: {name} with its corresponding submodule outputs" - ) - if "tensor_meta" not in submodule_node_dict[name].meta: - tensor_meta_list = [ - metadata["tensor_meta"] - for metadata in metadata_list - if "tensor_meta" in metadata - ] - submodule_node_dict[name].meta["tensor_meta"] = tensor_meta_list - logger.debug( - f"Updated tensor_meta metadata for node: {name} with its corresponding submodule outputs" - ) + metadata_keys = ["val", "tensor_meta"] + for key in metadata_keys: + if key not in submodule_node_dict[name].meta: + meta_val_list = [ + metadata[key] for metadata in metadata_list if key in metadata + ] + submodule_node_dict[name].meta[key] = meta_val_list + logger.debug( + f"Updated metadata for node: {name} with its corresponding submodule outputs" + ) + break subgraph_data = PerSubgraphData() subgraph_data.subgraph_name = name From c0237141a696694c1bdaac6c0fd4d4b501e9e3a4 Mon Sep 17 00:00:00 2001 From: lanluo-nvidia Date: Wed, 30 Oct 2024 10:10:30 -0700 Subject: [PATCH 30/33] resolve comments --- tests/py/dynamo/conversion/harness.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/py/dynamo/conversion/harness.py b/tests/py/dynamo/conversion/harness.py index 943fb9d1b1..cf0da5ca1c 100644 --- a/tests/py/dynamo/conversion/harness.py +++ b/tests/py/dynamo/conversion/harness.py @@ -20,6 +20,7 @@ # Use interpreter, input spec, and test case from fx_ts_compat to test Dynamo Converter Registry from torch_tensorrt.dynamo.conversion import TRTInterpreter +from torch_tensorrt.dynamo.conversion._conversion import infer_module_output_dtypes from torch_tensorrt.dynamo.lowering import ( get_decompositions, post_lowering, From 594ca28b9c4294210a7c7e85e6db646064ab7eaa Mon Sep 17 00:00:00 2001 From: lanluo-nvidia Date: Wed, 30 Oct 2024 16:32:24 -0700 Subject: [PATCH 31/33] resolve comments --- tests/py/dynamo/conversion/harness.py | 4 ---- 1 file changed, 4 deletions(-) diff --git 
a/tests/py/dynamo/conversion/harness.py b/tests/py/dynamo/conversion/harness.py index cf0da5ca1c..cdb62c52b8 100644 --- a/tests/py/dynamo/conversion/harness.py +++ b/tests/py/dynamo/conversion/harness.py @@ -366,10 +366,6 @@ def generate_graph( tuple(torch_export_inputs), dynamic_shapes=torch_export_dynamic_shapes, ) - exported_program = pre_export_lowering(exported_program, settings) - exported_program = exported_program.run_decompositions( - get_decompositions(False) - ) fx_module = exported_program.module() else: fx_module = torch.fx.symbolic_trace(mod) From 56d034b00eb876065a48da48b4fd4682b7b13da7 Mon Sep 17 00:00:00 2001 From: lanluo-nvidia Date: Thu, 31 Oct 2024 08:03:43 -0700 Subject: [PATCH 32/33] resolve comments --- tests/py/dynamo/conversion/harness.py | 5 +++++ tests/py/dynamo/conversion/test_convolution_aten.py | 6 ++++++ tests/py/dynamo/conversion/test_deconvolution_aten.py | 6 ++++++ tests/py/dynamo/conversion/test_pool_aten.py | 5 ++++- 4 files changed, 21 insertions(+), 1 deletion(-) diff --git a/tests/py/dynamo/conversion/harness.py b/tests/py/dynamo/conversion/harness.py index cdb62c52b8..61f891267e 100644 --- a/tests/py/dynamo/conversion/harness.py +++ b/tests/py/dynamo/conversion/harness.py @@ -366,6 +366,11 @@ def generate_graph( tuple(torch_export_inputs), dynamic_shapes=torch_export_dynamic_shapes, ) + if enable_passes: + exported_program = pre_export_lowering(exported_program, settings) + exported_program = exported_program.run_decompositions( + get_decompositions(False) + ) fx_module = exported_program.module() else: fx_module = torch.fx.symbolic_trace(mod) diff --git a/tests/py/dynamo/conversion/test_convolution_aten.py b/tests/py/dynamo/conversion/test_convolution_aten.py index 95f4de92b5..5291196abf 100644 --- a/tests/py/dynamo/conversion/test_convolution_aten.py +++ b/tests/py/dynamo/conversion/test_convolution_aten.py @@ -42,6 +42,7 @@ def forward(self, x): TestModule(), inputs, use_dynamo_tracer=True, + enable_passes=True, ) def test_conv1d_with_dynamic_shape( @@ -75,6 +76,7 @@ def forward(self, x): TestModule(), input_specs, use_dynamo_tracer=True, + enable_passes=True, ) @parameterized.expand( @@ -119,6 +121,7 @@ def forward(self, x): TestModule(), inputs, use_dynamo_tracer=True, + enable_passes=True, ) # Testing with (-1, -1, -1, -1) results into Error: @@ -144,6 +147,7 @@ def forward(self, x): TestModule(), input_specs, use_dynamo_tracer=True, + enable_passes=True, ) @parameterized.expand( @@ -182,6 +186,7 @@ def forward(self, x): TestModule(), inputs, use_dynamo_tracer=True, + enable_passes=True, ) # Testing with (-1, -1, -1, -1, -1) results into Error: @@ -207,6 +212,7 @@ def forward(self, x): TestModule(), input_specs, use_dynamo_tracer=True, + enable_passes=True, ) diff --git a/tests/py/dynamo/conversion/test_deconvolution_aten.py b/tests/py/dynamo/conversion/test_deconvolution_aten.py index 307275dba1..d6cbc0579f 100644 --- a/tests/py/dynamo/conversion/test_deconvolution_aten.py +++ b/tests/py/dynamo/conversion/test_deconvolution_aten.py @@ -49,6 +49,7 @@ def forward(self, x): TestModule(), inputs, use_dynamo_tracer=True, + enable_passes=True, ) def test_deconv1d_with_dynamic_shape( @@ -89,6 +90,7 @@ def forward(self, x): TestModule(), input_specs, use_dynamo_tracer=True, + enable_passes=True, ) @parameterized.expand( @@ -133,6 +135,7 @@ def forward(self, x): TestModule(), inputs, use_dynamo_tracer=True, + enable_passes=True, ) # Testing with (-1, -1, -1, -1) results into Error: @@ -158,6 +161,7 @@ def forward(self, x): TestModule(), input_specs, 
use_dynamo_tracer=True, + enable_passes=True, ) @parameterized.expand( @@ -202,6 +206,7 @@ def forward(self, x): TestModule(), inputs, use_dynamo_tracer=True, + enable_passes=True, ) # Testing with (-1, -1, -1, -1, -1) results into Error: @@ -227,6 +232,7 @@ def forward(self, x): TestModule(), input_specs, use_dynamo_tracer=True, + enable_passes=True, ) diff --git a/tests/py/dynamo/conversion/test_pool_aten.py b/tests/py/dynamo/conversion/test_pool_aten.py index 29fdf30480..f2e3261648 100644 --- a/tests/py/dynamo/conversion/test_pool_aten.py +++ b/tests/py/dynamo/conversion/test_pool_aten.py @@ -36,6 +36,7 @@ def forward(self, x): TestModule(), inputs, use_dynamo_tracer=True, + enable_passes=True, ) @parameterized.expand( @@ -146,7 +147,9 @@ def forward(self, x): ), ] - self.run_test_with_dynamic_shape(pool1d(), input_specs, use_dynamo_tracer=True) + self.run_test_with_dynamic_shape( + pool1d(), input_specs, use_dynamo_tracer=True, enable_passes=True + ) @parameterized.expand( [ From 8be5d5f218db1f24643b1f57c5566080291ad5cf Mon Sep 17 00:00:00 2001 From: lanluo-nvidia Date: Tue, 5 Nov 2024 12:35:57 -0800 Subject: [PATCH 33/33] resolve comments --- tests/py/dynamo/conversion/test_attention.py | 280 ------------------- 1 file changed, 280 deletions(-) diff --git a/tests/py/dynamo/conversion/test_attention.py b/tests/py/dynamo/conversion/test_attention.py index 984a5a3093..c174c15a02 100644 --- a/tests/py/dynamo/conversion/test_attention.py +++ b/tests/py/dynamo/conversion/test_attention.py @@ -34,286 +34,6 @@ def forward(self, query, key, value): enable_passes=True, ) - @parameterized.expand( - [ - ( - "4d-2d", - (4, 2, 16, 32), - (6, 3, 32, 64), - (32, 8, 64, 128), - (4, 32), - (4, 64), - (16, 128), - ), - ( - "4d-3d", - (2, 2, 2, 2), - (3, 3, 3, 4), - (3, 4, 4, 5), - (2, 3, 2), - (3, 3, 4), - (4, 5, 5), - ), - ( - "4d-4d", - (4, 2, 12, 4), - (6, 3, 16, 8), - (32, 8, 18, 16), - (4, 2, 4, 16), - (6, 3, 8, 32), - (32, 8, 12, 64), - ), - ] - ) - @unittest.skip("need to change to custom dynamic shapes") - def test_sdpa_no_causal_dynamic_shape_with_scale( - self, - _, - query_min_shape, - query_opt_shape, - query_max_shape, - key_min_shape, - key_opt_shape, - key_max_shape, - ): - class SDPA(nn.Module): - def forward(self, query, key, value): - return torch.nn.functional.scaled_dot_product_attention( - query, - key, - value, - None, - 0.0, - is_causal=False, - scale=-0.5, - ) - - inputs = [ - # query - Input( - dtype=torch.float32, - min_shape=query_min_shape, - opt_shape=query_opt_shape, - max_shape=query_max_shape, - ), - # key - Input( - dtype=torch.float32, - min_shape=key_min_shape, - opt_shape=key_opt_shape, - max_shape=key_max_shape, - ), - # value - Input( - dtype=torch.float32, - min_shape=key_min_shape, - opt_shape=key_opt_shape, - max_shape=key_max_shape, - ), - ] - dyn_dim_0 = Dim("dyn_dim_0", min=4, max=32) - dyn_dim_1 = Dim("dyn_dim_1", min=2, max=8) - - q_dyn_dim_2 = Dim("q_dyn_dim_2", min=12, max=18) - q_dyn_dim_3 = Dim("q_dyn_dim_3", min=4, max=16) - - k_dyn_dim_2 = Dim("k_dyn_dim_2", min=4, max=12) - k_dyn_dim_3 = 4 * q_dyn_dim_3 # Dim("k_dyn_dim_3", min=16, max=64) - - torch_export_dynamic_shapes = {} - torch_export_dynamic_shapes["query"] = { - 0: dyn_dim_0, - 1: dyn_dim_1, - 2: q_dyn_dim_2, - 3: q_dyn_dim_3, - } - torch_export_dynamic_shapes["key"] = { - 0: dyn_dim_0, - 1: dyn_dim_1, - 2: k_dyn_dim_2, - 3: k_dyn_dim_3, - } - torch_export_dynamic_shapes["value"] = { - 0: dyn_dim_0, - 1: dyn_dim_1, - 2: k_dyn_dim_2, - 3: k_dyn_dim_3, - } - self.run_test_with_dynamic_shape( - 
SDPA(), - inputs, - torch_export_dynamic_shapes=torch_export_dynamic_shapes, - enable_passes=True, - ) - - @parameterized.expand( - [ - ( - "4d-2d", - (4, 2, 128, 16), - (6, 3, 128, 32), - (32, 8, 128, 64), - (4, 16), - (4, 32), - (16, 64), - ), - ( - "4d-4d", - (4, 2, 12, 16), - (6, 3, 16, 32), - (32, 8, 18, 64), - (4, 2, 4, 16), - (6, 3, 8, 32), - (32, 8, 12, 64), - ), - ] - ) - @unittest.skip("need to change to custom dynamic shapes") - def test_sdpa_no_causal_no_scale_dynamic_shape( - self, - _, - query_min_shape, - query_opt_shape, - query_max_shape, - key_min_shape, - key_opt_shape, - key_max_shape, - ): - class SDPA(nn.Module): - def forward(self, query, key, value): - return torch.nn.functional.scaled_dot_product_attention( - query, - key, - value, - None, - 0.0, - is_causal=False, - scale=None, - ) - - inputs = [ - # query - Input( - dtype=torch.float32, - min_shape=query_min_shape, - opt_shape=query_opt_shape, - max_shape=query_max_shape, - ), - # key - Input( - dtype=torch.float32, - min_shape=key_min_shape, - opt_shape=key_opt_shape, - max_shape=key_max_shape, - ), - # value - Input( - dtype=torch.float32, - min_shape=key_min_shape, - opt_shape=key_opt_shape, - max_shape=key_max_shape, - ), - ] - - self.run_test_with_dynamic_shape(SDPA(), inputs) - - @parameterized.expand( - [ - ( - "4d-2d", - (2, 2, 3, 2), - (3, 3, 4, 2), - (4, 4, 5, 3), - (2, 2), - (3, 2), - (4, 3), - None, - ), - ( - "4d-3d", - (4, 2, 2, 16), - (6, 3, 3, 32), - (32, 4, 5, 64), - (2, 2, 16), - (3, 3, 32), - (4, 4, 64), - 0.1, - ), - ( - "4d-4d", - (4, 2, 2, 4), - (6, 3, 3, 8), - (32, 8, 6, 16), - (4, 2, 3, 4), - (6, 3, 4, 8), - (32, 8, 4, 16), - 0.01, - ), - ] - ) - @unittest.skip("need to change to custom dynamic shapes") - def test_sdpa_causal_dynamic_shape( - self, - _, - query_min_shape, - query_opt_shape, - query_max_shape, - key_min_shape, - key_opt_shape, - key_max_shape, - scale, - ): - class SDPA(nn.Module): - def forward(self, query, key, value): - return torch.nn.functional.scaled_dot_product_attention( - query, key, value, None, 0.0, True, scale=scale - ) - - inputs = [ - # query - Input( - dtype=torch.float32, - min_shape=query_min_shape, - opt_shape=query_opt_shape, - max_shape=query_max_shape, - ), - # key - Input( - dtype=torch.float32, - min_shape=key_min_shape, - opt_shape=key_opt_shape, - max_shape=key_max_shape, - ), - # value - Input( - dtype=torch.float32, - min_shape=key_min_shape, - opt_shape=key_opt_shape, - max_shape=key_max_shape, - ), - ] - - self.run_test_with_dynamic_shape(SDPA(), inputs) - - @parameterized.expand([((32, 8, 128, 64), (32, 8, 128, 64))]) - # it is already added in the test_composite_aten_op.py as integration test - @unittest.skip( - "skip torch.nn.functional.scaled_dot_product_attention converter test" - ) - def test_sdpa_causal(self, query_shape, key_shape): - class SDPA(nn.Module): - def forward(self, query, key, value): - return torch.nn.functional.scaled_dot_product_attention( - query, key, value, None, 0.0, True, scale=None - ) - - inputs = [] - query = torch.randn(query_shape, dtype=torch.float16) - key = torch.rand(key_shape, dtype=torch.float16) - value = torch.rand(key_shape, dtype=torch.float16) - inputs.extend([query, key, value]) - self.run_test(SDPA(), inputs, rtol=1e-2, atol=1e-2, precision=torch.float16) - @unittest.skipIf( torch.cuda.get_device_properties(torch.cuda.current_device()).major < 8,