diff --git a/tests/py/dynamo/conversion/harness.py b/tests/py/dynamo/conversion/harness.py index e648fb1e53..61f891267e 100644 --- a/tests/py/dynamo/conversion/harness.py +++ b/tests/py/dynamo/conversion/harness.py @@ -14,10 +14,13 @@ from torch_tensorrt._Device import Device from torch_tensorrt._enums import dtype from torch_tensorrt.dynamo import _defaults +from torch_tensorrt.dynamo._defaults import default_device from torch_tensorrt.dynamo._settings import CompilationSettings +from torch_tensorrt.dynamo._tracer import get_dynamic_shapes_args # Use interpreter, input spec, and test case from fx_ts_compat to test Dynamo Converter Registry from torch_tensorrt.dynamo.conversion import TRTInterpreter +from torch_tensorrt.dynamo.conversion._conversion import infer_module_output_dtypes from torch_tensorrt.dynamo.lowering import ( get_decompositions, post_lowering, @@ -29,6 +32,77 @@ _LOGGER: logging.Logger = logging.getLogger(__name__) +# this method is only used in our converter test to infer the module output dtypes via dummy inference +# which is due to fx.symbolic_trace does not have the meta['val'] info in the node +# TODO: lan to remove this once our converter test is moved from fx.symbolic_trace to dynamo trace +def infer_module_output_dtypes_for_test( + module: torch.fx.GraphModule, + inputs: Sequence[Input], + device: Device, + kwarg_inputs: Optional[dict[str, Any]] = None, + truncate_double: bool = False, +) -> List[dtype]: + """ + This function performs model inference to determine the output dtypes + and truncates them accordingly. inputs can be either arg_inputs or flattened input list. + If it is flattened list, kwarg_inputs should be None, as it is already included in the flattened input. + """ + # TODO: We can also determine output dtypes from the module.graph based on node metadata. + # However, our converter tests use fx.symbolic_trace which sometimes does not provide metadata, + # so we stick to the model inference approach currently. 
+ with unset_fake_temporarily(): + # Get the device on which the model exists + # For large models, this can be done on CPU to save GPU memory allocation for TRT. + device = get_model_device(module) + torch_inputs = get_torch_inputs(inputs, device) + if kwarg_inputs is None: + kwarg_inputs = {} + torch_kwarg_inputs = get_torch_inputs(kwarg_inputs, device) + module_outputs = module(*torch_inputs, **torch_kwarg_inputs) + if not isinstance(module_outputs, (list, tuple)): + module_outputs = [module_outputs] + + # Int64 outputs can sometimes be generated from within other operators + # such as aten.sum - such outputs can be truncated + output_dtypes = [] + for output in module_outputs: + output_ = output + # We don't need to check if output is nested here because the input module will be flattened + if not isinstance(output, torch.Tensor): + if isinstance(output, str): + raise ValueError( + f"Received an output type {type(output)} that's not in the acceptable datatypes (https://pytorch.org/docs/stable/tensor_attributes.html#torch.dtype)" + ) + else: + output_ = torch.tensor(output) + + if truncate_double and output_.dtype == dtype.float64: + output_dtypes.append(dtype.float32) + else: + output_dtypes.append(dtype._from(output_.dtype)) + + return output_dtypes + + +# this is to enable dynamo tracer as True in the converter test files batch by batch +def get_use_dynamo_tracer(use_dynamo_tracer: Any) -> bool: + # if in our converter tests we specifically set use_dynamo_tracer field, honor it + if use_dynamo_tracer is not None and isinstance(use_dynamo_tracer, bool): + return use_dynamo_tracer + # if in our converter tests, we did not specify use_dynamo_tracer field + import inspect + import os + import re + + filename = os.path.basename(inspect.stack()[2].filename) + # enable converter test files which starts with test_a*.py to use dynamo tracer + pattern = re.compile("^test_([a])+") + if pattern.match(filename): + return True + else: + return False + + # this method is 
only used in our converter test to infer the module output dtypes via dummy inference # which is due to fx.symbolic_trace does not have the meta['val'] info in the node # TODO: lan to remove this once our converter test is moved from fx.symbolic_trace to dynamo trace @@ -277,14 +351,26 @@ def generate_graph( enable_passes: bool, propagate_shapes: bool = False, settings: CompilationSettings = CompilationSettings(), + torch_export_dynamic_shapes: Optional[Any] = None, ): mod = mod.eval() if use_dynamo_tracer: - exported_program = torch_tensorrt.dynamo.trace(mod, tuple(original_inputs)) - exported_program = pre_export_lowering(exported_program, settings) - exported_program = exported_program.run_decompositions( - get_decompositions(False) + if torch_export_dynamic_shapes is None: + torch_export_dynamic_shapes = get_dynamic_shapes_args( + mod, original_inputs + ) + device = default_device() + torch_export_inputs = get_torch_inputs(original_inputs, device) + exported_program = torch.export.export( + mod, + tuple(torch_export_inputs), + dynamic_shapes=torch_export_dynamic_shapes, ) + if enable_passes: + exported_program = pre_export_lowering(exported_program, settings) + exported_program = exported_program.run_decompositions( + get_decompositions(False) + ) fx_module = exported_program.module() else: fx_module = torch.fx.symbolic_trace(mod) @@ -313,13 +399,15 @@ def run_test( atol=ATOL, precision=dtype.f32, check_dtype=True, - use_dynamo_tracer=False, + use_dynamo_tracer=None, enable_passes=False, propagate_shapes=False, int32_reqd=False, make_refittable=False, ): - + # TODO: lan to remove this and set use_dynamo_tracer to True by default + # once all the converter test files are moved to use_dynamo_tracer + use_dynamo_tracer = get_use_dynamo_tracer(use_dynamo_tracer) # Previous instance of the interpreter auto-casted 64-bit inputs # We replicate this behavior here compilation_settings = CompilationSettings( @@ -366,12 +454,18 @@ def run_test( output_dtypes = None if 
check_dtype: - output_dtypes = infer_module_output_dtypes_for_test( - mod, - input_specs, - compilation_settings.device, - truncate_double=compilation_settings.truncate_double, - ) + if use_dynamo_tracer: + output_dtypes = infer_module_output_dtypes( + mod, + truncate_double=compilation_settings.truncate_double, + ) + else: + output_dtypes = infer_module_output_dtypes_for_test( + mod, + input_specs, + compilation_settings.device, + truncate_double=compilation_settings.truncate_double, + ) _LOGGER.debug(f"Compilation settings: {compilation_settings}") _LOGGER.debug(f"Inputs: {input_specs}") @@ -441,21 +535,24 @@ def run_test_with_dynamic_shape( rtol=RTOL, atol=ATOL, output_dtypes=None, - use_dynamo_tracer=False, + use_dynamo_tracer=None, enable_passes=False, use_example_tensors=True, pyt_inputs=None, propagate_shapes=False, check_dtype=True, make_refittable=False, + torch_export_dynamic_shapes=None, ): + # TODO: lan to remove this and set use_dynamo_tracer to True by default + # once all the converter test files are moved to use_dynamo_tracer + use_dynamo_tracer = get_use_dynamo_tracer(use_dynamo_tracer) # Previous instance of the interpreter auto-casted 64-bit inputs # We replicate this behavior here compilation_settings = CompilationSettings( truncate_double=True, make_refittable=make_refittable ) - mod = self.generate_graph( mod, input_specs, @@ -463,15 +560,22 @@ def run_test_with_dynamic_shape( enable_passes=enable_passes, propagate_shapes=propagate_shapes, settings=compilation_settings, + torch_export_dynamic_shapes=torch_export_dynamic_shapes, ) if check_dtype: - output_dtypes = infer_module_output_dtypes_for_test( - mod, - input_specs, - compilation_settings.device, - truncate_double=compilation_settings.truncate_double, - ) + if use_dynamo_tracer: + output_dtypes = infer_module_output_dtypes( + mod, + truncate_double=compilation_settings.truncate_double, + ) + else: + output_dtypes = infer_module_output_dtypes_for_test( + mod, + input_specs, + 
compilation_settings.device, + truncate_double=compilation_settings.truncate_double, + ) interp = TRTInterpreter( mod, diff --git a/tests/py/dynamo/conversion/test_acos_aten.py b/tests/py/dynamo/conversion/test_acos_aten.py index 81b83bcc4a..8e93e0a309 100644 --- a/tests/py/dynamo/conversion/test_acos_aten.py +++ b/tests/py/dynamo/conversion/test_acos_aten.py @@ -66,7 +66,7 @@ def forward(self, input): ( "3d_dim_dtype_float", (1, 1, 1), - (1, 2, 3), + (2, 2, 3), (3, 3, 3), torch.float, torch.float, diff --git a/tests/py/dynamo/conversion/test_acosh_aten.py b/tests/py/dynamo/conversion/test_acosh_aten.py index 090756ddfb..dd19c188c5 100644 --- a/tests/py/dynamo/conversion/test_acosh_aten.py +++ b/tests/py/dynamo/conversion/test_acosh_aten.py @@ -58,7 +58,7 @@ def forward(self, input): ( "3d_dim_dtype_float", (1, 1, 1), - (1, 2, 3), + (2, 2, 3), (3, 3, 3), torch.float, torch.float, @@ -66,7 +66,7 @@ def forward(self, input): ( "3d_dim_dtype_int32", (1, 1, 1), - (1, 2, 4), + (2, 2, 4), (2, 3, 5), torch.int32, torch.float, diff --git a/tests/py/dynamo/conversion/test_any.py b/tests/py/dynamo/conversion/test_any.py index 1d1fc634ef..75620f7c34 100644 --- a/tests/py/dynamo/conversion/test_any.py +++ b/tests/py/dynamo/conversion/test_any.py @@ -191,7 +191,7 @@ class TestAnyConverterDynamic(DispatchTestCase): ( "3d_dynamic_float", (2, 1, 1), - (2, 2, 1), + (2, 2, 2), (3, 2, 4), torch.float, ), @@ -234,7 +234,7 @@ def forward(self, x): ( "3d_dynamic_dim_float", (2, 1, 1), - (2, 2, 1), + (2, 2, 2), (3, 2, 4), torch.float, 2, @@ -252,7 +252,7 @@ def forward(self, x): ( "3d_dynamic_dim_bool", (2, 1, 1), - (2, 2, 1), + (2, 2, 2), (3, 2, 4), torch.bool, 0, @@ -285,7 +285,7 @@ def forward(self, x): ( "3d_dynamic_dims_float", (2, 1, 1), - (2, 2, 1), + (2, 2, 2), (3, 2, 4), torch.float, [1, 2], diff --git a/tests/py/dynamo/conversion/test_arange_aten.py b/tests/py/dynamo/conversion/test_arange_aten.py index 1e1e9b3cc7..cb3e4c6b51 100644 --- 
a/tests/py/dynamo/conversion/test_arange_aten.py +++ b/tests/py/dynamo/conversion/test_arange_aten.py @@ -56,6 +56,7 @@ def forward(self, end_tensor): use_example_tensors=False, check_dtype=False, pyt_inputs=[pyt_input], + use_dynamo_tracer=False, ) diff --git a/tests/py/dynamo/conversion/test_asin_aten.py b/tests/py/dynamo/conversion/test_asin_aten.py index 2b8eb84144..f520e5c7a3 100644 --- a/tests/py/dynamo/conversion/test_asin_aten.py +++ b/tests/py/dynamo/conversion/test_asin_aten.py @@ -66,7 +66,7 @@ def forward(self, input): ( "3d_dim_dtype_float", (1, 1, 1), - (1, 2, 3), + (2, 2, 3), (3, 3, 3), torch.float, torch.float, diff --git a/tests/py/dynamo/conversion/test_asinh_aten.py b/tests/py/dynamo/conversion/test_asinh_aten.py index c5fdc71883..ca9b7e9126 100644 --- a/tests/py/dynamo/conversion/test_asinh_aten.py +++ b/tests/py/dynamo/conversion/test_asinh_aten.py @@ -58,7 +58,7 @@ def forward(self, input): ( "3d_dim_dtype_float", (1, 1, 1), - (1, 2, 3), + (2, 2, 3), (3, 3, 3), torch.float, torch.float, @@ -66,7 +66,7 @@ def forward(self, input): ( "3d_dim_dtype_int32", (1, 1, 1), - (1, 2, 4), + (2, 2, 4), (2, 3, 5), torch.int32, torch.float, diff --git a/tests/py/dynamo/conversion/test_atan2_aten.py b/tests/py/dynamo/conversion/test_atan2_aten.py index b684a960ac..2db650cbaa 100644 --- a/tests/py/dynamo/conversion/test_atan2_aten.py +++ b/tests/py/dynamo/conversion/test_atan2_aten.py @@ -1,3 +1,5 @@ +import unittest + import torch import torch.nn as nn from parameterized import parameterized @@ -141,7 +143,7 @@ def forward(self, lhs_val, rhs_val): ( "3d_dim_dtype_float", (1, 1, 1), - (1, 2, 3), + (2, 2, 3), (3, 3, 3), torch.float, torch.float, @@ -149,7 +151,7 @@ def forward(self, lhs_val, rhs_val): ( "3d_dim_dtype_int32", (1, 1, 1), - (1, 2, 4), + (2, 2, 4), (2, 3, 5), torch.int32, torch.float, @@ -182,10 +184,17 @@ def forward(self, lhs_val, rhs_val): ) +# torch.ops.aten.atan2.out will be decomposed/partitioned into core aten ops which torch_tensorrt 
supported in run_on_acc and +# non supported ops in run_on_gpu in dynamo tracer, it works via torch_tensorrt.dynamo.compile workflow +# but it won't be valid for our converter test framework, so skip it here. +@unittest.skip("skip torch.ops.aten.atan2.out converter test") class TestAtan2OutConverter(DispatchTestCase): @parameterized.expand( [ - ((10,), (5,), torch.float), + # dynamo trace does not allow output to be in a different shape + # raise Unsupported(msg, case_name=case_name) + # torch._dynamo.exc.Unsupported: out variants with resizing on graph inputs + ((5,), (5,), torch.float), ((10,), (10,), torch.float), ] ) @@ -220,7 +229,7 @@ def forward(self, lhs_val, rhs_val, out): ( "3d_dim_dtype_float", (1, 1, 1), - (1, 2, 3), + (2, 2, 3), (3, 3, 3), torch.float, torch.float, @@ -255,7 +264,10 @@ def forward(self, lhs_val, rhs_val, out): ), ] self.run_test_with_dynamic_shape( - atan2(), input_specs, output_dtypes=[output_type] + atan2(), + input_specs, + output_dtypes=[output_type], + use_dynamo_tracer=False, ) diff --git a/tests/py/dynamo/conversion/test_atan_aten.py b/tests/py/dynamo/conversion/test_atan_aten.py index dda06404cb..2df58d4f69 100644 --- a/tests/py/dynamo/conversion/test_atan_aten.py +++ b/tests/py/dynamo/conversion/test_atan_aten.py @@ -58,7 +58,7 @@ def forward(self, input): ( "3d_dim_dtype_float", (1, 1, 1), - (1, 2, 3), + (2, 2, 3), (3, 3, 3), torch.float, torch.float, @@ -66,7 +66,7 @@ def forward(self, input): ( "3d_dim_dtype_int32", (1, 1, 1), - (1, 2, 4), + (2, 2, 4), (2, 3, 5), torch.int32, torch.float, diff --git a/tests/py/dynamo/conversion/test_atanh_aten.py b/tests/py/dynamo/conversion/test_atanh_aten.py index 3b7ac77541..2438bf8252 100644 --- a/tests/py/dynamo/conversion/test_atanh_aten.py +++ b/tests/py/dynamo/conversion/test_atanh_aten.py @@ -58,7 +58,7 @@ def forward(self, input): ( "3d_dim_dtype_float", (1, 1, 1), - (1, 2, 3), + (2, 2, 3), (3, 3, 3), torch.float, torch.float, @@ -66,7 +66,7 @@ def forward(self, input): ( 
"3d_dim_dtype_int32", (1, 1, 1), - (1, 2, 4), + (2, 2, 4), (2, 3, 5), torch.int32, torch.float, diff --git a/tests/py/dynamo/conversion/test_attention.py b/tests/py/dynamo/conversion/test_attention.py index 5109a2e2ca..c174c15a02 100644 --- a/tests/py/dynamo/conversion/test_attention.py +++ b/tests/py/dynamo/conversion/test_attention.py @@ -3,6 +3,7 @@ import torch import torch.nn as nn from parameterized import parameterized +from torch.export import Dim from torch.testing._internal.common_utils import run_tests from torch_tensorrt import Input @@ -24,248 +25,14 @@ def forward(self, query, key, value): key = torch.rand(key_shape, dtype=torch.float16) value = torch.rand(key_shape, dtype=torch.float16) inputs.extend([query, key, value]) - self.run_test(SDPA(), inputs, rtol=1e-2, atol=1e-2, precision=torch.float16) - - @parameterized.expand( - [ - ( - "4d-2d", - (4, 2, 16, 32), - (6, 3, 32, 64), - (32, 8, 64, 128), - (4, 32), - (4, 64), - (16, 128), - ), - ( - "4d-3d", - (2, 2, 2, 2), - (3, 3, 3, 4), - (3, 4, 4, 5), - (2, 3, 2), - (3, 3, 4), - (4, 5, 5), - ), - ( - "4d-4d", - (4, 2, 12, 16), - (6, 3, 16, 32), - (32, 8, 18, 64), - (4, 2, 4, 16), - (6, 3, 8, 32), - (32, 8, 12, 64), - ), - ] - ) - def test_sdpa_no_causal_dynamic_shape_with_scale( - self, - _, - query_min_shape, - query_opt_shape, - query_max_shape, - key_min_shape, - key_opt_shape, - key_max_shape, - ): - class SDPA(nn.Module): - def forward(self, query, key, value): - return torch.nn.functional.scaled_dot_product_attention( - query, - key, - value, - None, - 0.0, - is_causal=False, - scale=-0.5, - ) - - inputs = [ - # query - Input( - dtype=torch.float32, - min_shape=query_min_shape, - opt_shape=query_opt_shape, - max_shape=query_max_shape, - ), - # key - Input( - dtype=torch.float32, - min_shape=key_min_shape, - opt_shape=key_opt_shape, - max_shape=key_max_shape, - ), - # value - Input( - dtype=torch.float32, - min_shape=key_min_shape, - opt_shape=key_opt_shape, - max_shape=key_max_shape, - ), - ] - - 
self.run_test_with_dynamic_shape(SDPA(), inputs) - - @parameterized.expand( - [ - ( - "4d-2d", - (4, 2, 128, 16), - (6, 3, 128, 32), - (32, 8, 128, 64), - (4, 16), - (4, 32), - (16, 64), - ), - ( - "4d-4d", - (4, 2, 12, 16), - (6, 3, 16, 32), - (32, 8, 18, 64), - (4, 2, 4, 16), - (6, 3, 8, 32), - (32, 8, 12, 64), - ), - ] - ) - def test_sdpa_no_causal_no_scale_dynamic_shape( - self, - _, - query_min_shape, - query_opt_shape, - query_max_shape, - key_min_shape, - key_opt_shape, - key_max_shape, - ): - class SDPA(nn.Module): - def forward(self, query, key, value): - return torch.nn.functional.scaled_dot_product_attention( - query, - key, - value, - None, - 0.0, - is_causal=False, - scale=None, - ) - - inputs = [ - # query - Input( - dtype=torch.float32, - min_shape=query_min_shape, - opt_shape=query_opt_shape, - max_shape=query_max_shape, - ), - # key - Input( - dtype=torch.float32, - min_shape=key_min_shape, - opt_shape=key_opt_shape, - max_shape=key_max_shape, - ), - # value - Input( - dtype=torch.float32, - min_shape=key_min_shape, - opt_shape=key_opt_shape, - max_shape=key_max_shape, - ), - ] - - self.run_test_with_dynamic_shape(SDPA(), inputs) - - @parameterized.expand( - [ - ( - "4d-2d", - (2, 2, 3, 2), - (3, 3, 4, 2), - (4, 4, 5, 3), - (2, 2), - (3, 2), - (4, 3), - None, - ), - ( - "4d-3d", - (4, 2, 2, 16), - (6, 3, 3, 32), - (32, 4, 5, 64), - (2, 2, 16), - (3, 3, 32), - (4, 4, 64), - 0.1, - ), - ( - "4d-4d", - (4, 2, 2, 4), - (6, 3, 3, 8), - (32, 8, 6, 16), - (4, 2, 3, 4), - (6, 3, 4, 8), - (32, 8, 4, 16), - 0.01, - ), - ] - ) - def test_sdpa_causal_dynamic_shape( - self, - _, - query_min_shape, - query_opt_shape, - query_max_shape, - key_min_shape, - key_opt_shape, - key_max_shape, - scale, - ): - class SDPA(nn.Module): - def forward(self, query, key, value): - return torch.nn.functional.scaled_dot_product_attention( - query, key, value, None, 0.0, True, scale=scale - ) - - inputs = [ - # query - Input( - dtype=torch.float32, - min_shape=query_min_shape, - 
opt_shape=query_opt_shape, - max_shape=query_max_shape, - ), - # key - Input( - dtype=torch.float32, - min_shape=key_min_shape, - opt_shape=key_opt_shape, - max_shape=key_max_shape, - ), - # value - Input( - dtype=torch.float32, - min_shape=key_min_shape, - opt_shape=key_opt_shape, - max_shape=key_max_shape, - ), - ] - - self.run_test_with_dynamic_shape(SDPA(), inputs) - - @parameterized.expand([((32, 8, 128, 64), (32, 8, 128, 64))]) - def test_sdpa_causal(self, query_shape, key_shape): - class SDPA(nn.Module): - def forward(self, query, key, value): - return torch.nn.functional.scaled_dot_product_attention( - query, key, value, None, 0.0, True, scale=None - ) - - inputs = [] - query = torch.randn(query_shape, dtype=torch.float16) - key = torch.rand(key_shape, dtype=torch.float16) - value = torch.rand(key_shape, dtype=torch.float16) - inputs.extend([query, key, value]) - self.run_test(SDPA(), inputs, rtol=1e-2, atol=1e-2, precision=torch.float16) + self.run_test( + SDPA(), + inputs, + rtol=1e-2, + atol=1e-2, + precision=torch.float16, + enable_passes=True, + ) @unittest.skipIf( diff --git a/tests/py/dynamo/conversion/test_composite_aten_op.py b/tests/py/dynamo/conversion/test_composite_aten_op.py new file mode 100644 index 0000000000..b2f13fb9c6 --- /dev/null +++ b/tests/py/dynamo/conversion/test_composite_aten_op.py @@ -0,0 +1,193 @@ +import unittest + +import pytest +import torch +import torch.nn as nn +import torch_tensorrt as torchtrt +from parameterized import parameterized +from torch.export import Dim +from torch_tensorrt import Input +from torch_tensorrt.dynamo.utils import COSINE_THRESHOLD, cosine_similarity + +assertions = unittest.TestCase() + + +@pytest.mark.unit +@parameterized.expand( + [ + ((5,), (5,)), + ( + ( + 2, + 3, + ), + ( + 2, + 3, + ), + ), + ] +) +def test_atan2_out_static_shape(input_shape, out_shape): + class atan2(torch.nn.Module): + def forward(self, lhs_val, rhs_val, out): + return torch.ops.aten.atan2.out(lhs_val, rhs_val, out=out) + 
+ model = atan2().eval().cuda() + inputs = ( + torch.randn(input_shape).cuda(), + torch.randn(input_shape).cuda(), + torch.randn(out_shape).cuda(), + ) + compile_spec = { + "inputs": inputs, + "ir": "dynamo", + "min_block_size": 1, + } + + trt_model = torchtrt.compile(model, **compile_spec) + py_outputs = model(*inputs) + trt_outputs = trt_model(*inputs) + cos_sim = cosine_similarity(py_outputs, trt_outputs) + assertions.assertTrue( + cos_sim > COSINE_THRESHOLD, + msg=f"test_atan2_out_static_shape model TRT outputs don't match with the pytorch model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", + ) + + +@pytest.mark.unit +@parameterized.expand( + [ + ( + ( + 1, + 2, + ), + (2, 3), + (2, 4), + ), + ] +) +def test_atan2_out_dynamic_shape(min_shape, opt_shape, max_shape): + class atan2(torch.nn.Module): + def forward(self, lhs_val, rhs_val, out): + return torch.ops.aten.atan2.out(lhs_val, rhs_val, out=out) + + model = atan2().eval().cuda() + input_spec = [ + Input( + min_shape=min_shape, + opt_shape=opt_shape, + max_shape=max_shape, + ), + Input( + min_shape=min_shape, + opt_shape=opt_shape, + max_shape=max_shape, + ), + Input( + min_shape=min_shape, + opt_shape=opt_shape, + max_shape=max_shape, + ), + ] + + compile_spec = { + "inputs": input_spec, + "ir": "dynamo", + "min_block_size": 1, + } + + trt_model = torchtrt.compile(model, **compile_spec) + inputs = ( + torch.randn(max_shape).cuda(), + torch.randn(max_shape).cuda(), + torch.randn(max_shape).cuda(), + ) + py_outputs = model(*inputs) + trt_outputs = trt_model(*inputs) + cos_sim = cosine_similarity(py_outputs, trt_outputs) + assertions.assertTrue( + cos_sim > COSINE_THRESHOLD, + msg=f"test_atan2_out_dynamic_shape model TRT outputs don't match with the pytorch model. 
Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", + ) + + +@parameterized.expand( + [ + ((32, 8, 128, 64), (32, 8, 128, 64), True, None), + ] +) +def test_sdpa_static_shape(query_shape, key_shape, is_causal, scale): + class SDPA(nn.Module): + def forward(self, query, key, value): + return torch.nn.functional.scaled_dot_product_attention( + query, key, value, None, 0.0, is_causal=is_causal, scale=scale + ) + + model = SDPA().eval().cuda() + + query = torch.randn(query_shape, dtype=torch.float16).cuda() + key = torch.randn(key_shape, dtype=torch.float16).cuda() + value = torch.randn(key_shape, dtype=torch.float16).cuda() + inputs = (query, key, value) + compile_spec = { + "inputs": inputs, + "ir": "dynamo", + "min_block_size": 1, + } + + trt_model = torchtrt.compile(model, **compile_spec) + py_outputs = model(*inputs) + trt_outputs = trt_model(*inputs) + cos_sim = cosine_similarity(py_outputs, trt_outputs) + assertions.assertTrue( + cos_sim > COSINE_THRESHOLD, + msg=f"test_sdpa_static_shape model TRT outputs don't match with the pytorch model. 
Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", + ) + + +@parameterized.expand( + [ + (True, None), + (True, 0.1), + (False, None), + ] +) +def test_sdpa_dynamic_shape(is_causal, scale): + class SDPA(nn.Module): + def forward(self, query, key, value): + return torch.nn.functional.scaled_dot_product_attention( + query, key, value, None, 0.0, is_causal=is_causal, scale=scale + ) + + model = SDPA().eval().cuda() + + # N: batch_size + dyn_N = Dim("dyn_N", min=2, max=4) + + # query tensor shape (N, ..., Hq, L, E) + query = torch.randn((3, 3, 4, 64), dtype=torch.float16).cuda() + # key tensor shape (N,...,H, S, E) + key = torch.randn((3, 3, 4, 64), dtype=torch.float16).cuda() + # value tensor shape (N, ..., H, S, Ev) + value = torch.randn((3, 3, 4, 64), dtype=torch.float16).cuda() + + dynamic_shapes = {"query": {0: dyn_N}, "key": {0: dyn_N}, "value": {0: dyn_N}} + inputs = (query, key, value) + + exp_program = torch.export.export(model, inputs, dynamic_shapes=dynamic_shapes) + + compile_spec = { + "inputs": inputs, + "ir": "dynamo", + "min_block_size": 1, + } + trt_model = torchtrt.dynamo.compile(exp_program, **compile_spec) + py_outputs = model(*inputs) + trt_outputs = trt_model(*inputs) + cos_sim = cosine_similarity(py_outputs, trt_outputs) + assertions.assertTrue( + cos_sim > COSINE_THRESHOLD, + msg=f"test_sdpa_dynamic_shape model TRT outputs don't match with the pytorch model. 
Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", + ) diff --git a/tests/py/dynamo/conversion/test_convolution_aten.py b/tests/py/dynamo/conversion/test_convolution_aten.py index 95f4de92b5..5291196abf 100644 --- a/tests/py/dynamo/conversion/test_convolution_aten.py +++ b/tests/py/dynamo/conversion/test_convolution_aten.py @@ -42,6 +42,7 @@ def forward(self, x): TestModule(), inputs, use_dynamo_tracer=True, + enable_passes=True, ) def test_conv1d_with_dynamic_shape( @@ -75,6 +76,7 @@ def forward(self, x): TestModule(), input_specs, use_dynamo_tracer=True, + enable_passes=True, ) @parameterized.expand( @@ -119,6 +121,7 @@ def forward(self, x): TestModule(), inputs, use_dynamo_tracer=True, + enable_passes=True, ) # Testing with (-1, -1, -1, -1) results into Error: @@ -144,6 +147,7 @@ def forward(self, x): TestModule(), input_specs, use_dynamo_tracer=True, + enable_passes=True, ) @parameterized.expand( @@ -182,6 +186,7 @@ def forward(self, x): TestModule(), inputs, use_dynamo_tracer=True, + enable_passes=True, ) # Testing with (-1, -1, -1, -1, -1) results into Error: @@ -207,6 +212,7 @@ def forward(self, x): TestModule(), input_specs, use_dynamo_tracer=True, + enable_passes=True, ) diff --git a/tests/py/dynamo/conversion/test_deconvolution_aten.py b/tests/py/dynamo/conversion/test_deconvolution_aten.py index 307275dba1..d6cbc0579f 100644 --- a/tests/py/dynamo/conversion/test_deconvolution_aten.py +++ b/tests/py/dynamo/conversion/test_deconvolution_aten.py @@ -49,6 +49,7 @@ def forward(self, x): TestModule(), inputs, use_dynamo_tracer=True, + enable_passes=True, ) def test_deconv1d_with_dynamic_shape( @@ -89,6 +90,7 @@ def forward(self, x): TestModule(), input_specs, use_dynamo_tracer=True, + enable_passes=True, ) @parameterized.expand( @@ -133,6 +135,7 @@ def forward(self, x): TestModule(), inputs, use_dynamo_tracer=True, + enable_passes=True, ) # Testing with (-1, -1, -1, -1) results into Error: @@ -158,6 +161,7 @@ def forward(self, x): TestModule(), 
input_specs, use_dynamo_tracer=True, + enable_passes=True, ) @parameterized.expand( @@ -202,6 +206,7 @@ def forward(self, x): TestModule(), inputs, use_dynamo_tracer=True, + enable_passes=True, ) # Testing with (-1, -1, -1, -1, -1) results into Error: @@ -227,6 +232,7 @@ def forward(self, x): TestModule(), input_specs, use_dynamo_tracer=True, + enable_passes=True, ) diff --git a/tests/py/dynamo/conversion/test_pool_aten.py b/tests/py/dynamo/conversion/test_pool_aten.py index 29fdf30480..f2e3261648 100644 --- a/tests/py/dynamo/conversion/test_pool_aten.py +++ b/tests/py/dynamo/conversion/test_pool_aten.py @@ -36,6 +36,7 @@ def forward(self, x): TestModule(), inputs, use_dynamo_tracer=True, + enable_passes=True, ) @parameterized.expand( @@ -146,7 +147,9 @@ def forward(self, x): ), ] - self.run_test_with_dynamic_shape(pool1d(), input_specs, use_dynamo_tracer=True) + self.run_test_with_dynamic_shape( + pool1d(), input_specs, use_dynamo_tracer=True, enable_passes=True + ) @parameterized.expand( [