
Commit

switch from fx.symbolic_trace to dynamo_trace for converter test part-1 (#3261)
lanluo-nvidia authored Nov 5, 2024
1 parent 3ecd5aa commit e43833d
Showing 15 changed files with 374 additions and 282 deletions.
144 changes: 124 additions & 20 deletions tests/py/dynamo/conversion/harness.py
@@ -14,10 +14,13 @@
from torch_tensorrt._Device import Device
from torch_tensorrt._enums import dtype
from torch_tensorrt.dynamo import _defaults
from torch_tensorrt.dynamo._defaults import default_device
from torch_tensorrt.dynamo._settings import CompilationSettings
from torch_tensorrt.dynamo._tracer import get_dynamic_shapes_args

# Use interpreter, input spec, and test case from fx_ts_compat to test Dynamo Converter Registry
from torch_tensorrt.dynamo.conversion import TRTInterpreter
from torch_tensorrt.dynamo.conversion._conversion import infer_module_output_dtypes
from torch_tensorrt.dynamo.lowering import (
get_decompositions,
post_lowering,
@@ -29,6 +32,77 @@
_LOGGER: logging.Logger = logging.getLogger(__name__)


# This helper is only used in our converter tests to infer the module output dtypes via dummy inference,
# because fx.symbolic_trace does not populate the meta['val'] info on graph nodes.
# TODO: lan to remove this once the converter tests are moved from fx.symbolic_trace to dynamo trace
def infer_module_output_dtypes_for_test(
module: torch.fx.GraphModule,
inputs: Sequence[Input],
device: Device,
kwarg_inputs: Optional[dict[str, Any]] = None,
truncate_double: bool = False,
) -> List[dtype]:
"""
This function performs model inference to determine the output dtypes
and truncates them accordingly. inputs can be either arg_inputs or flattened input list.
If it is flattened list, kwarg_inputs should be None, as it is already included in the flattened input.
"""
# TODO: We can also determine output dtypes from the module.graph based on node metadata.
# However, our converter tests use fx.symbolic_trace which sometimes does not provide metadata,
# so we stick to the model inference approach currently.
with unset_fake_temporarily():
# Get the device on which the model exists
# For large models, this can be done on CPU to save GPU memory allocation for TRT.
device = get_model_device(module)
torch_inputs = get_torch_inputs(inputs, device)
if kwarg_inputs is None:
kwarg_inputs = {}
torch_kwarg_inputs = get_torch_inputs(kwarg_inputs, device)
module_outputs = module(*torch_inputs, **torch_kwarg_inputs)
if not isinstance(module_outputs, (list, tuple)):
module_outputs = [module_outputs]

# Int64 outputs can sometimes be generated from within other operators
# such as aten.sum - such outputs can be truncated
output_dtypes = []
for output in module_outputs:
output_ = output
# We don't need to check if output is nested here because the input module will be flattened
if not isinstance(output, torch.Tensor):
if isinstance(output, str):
raise ValueError(
f"Received an output type {type(output)} that's not in the acceptable datatypes (https://pytorch.org/docs/stable/tensor_attributes.html#torch.dtype)"
)
else:
output_ = torch.tensor(output)

if truncate_double and output_.dtype == dtype.float64:
output_dtypes.append(dtype.float32)
else:
output_dtypes.append(dtype._from(output_.dtype))

return output_dtypes


# Helper used to enable the dynamo tracer in the converter test files batch by batch
def get_use_dynamo_tracer(use_dynamo_tracer: Any) -> bool:
# If a converter test explicitly sets the use_dynamo_tracer field, honor it
if use_dynamo_tracer is not None and isinstance(use_dynamo_tracer, bool):
return use_dynamo_tracer
# If the converter test did not specify the use_dynamo_tracer field, decide based on the test file name
import inspect
import os
import re

filename = os.path.basename(inspect.stack()[2].filename)
# enable converter test files whose names start with test_a*.py to use the dynamo tracer
pattern = re.compile("^test_([a])+")
if pattern.match(filename):
return True
else:
return False


# This helper is only used in our converter tests to infer the module output dtypes via dummy inference,
# because fx.symbolic_trace does not populate the meta['val'] info on graph nodes.
# TODO: lan to remove this once the converter tests are moved from fx.symbolic_trace to dynamo trace
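Not part of the diff: a minimal check of the file-name gating that get_use_dynamo_tracer above applies when a test does not set use_dynamo_tracer explicitly. The file names below are only examples.

```python
import re

# Same pattern as in get_use_dynamo_tracer: only converter test files whose
# names start with "test_a" opt in to the dynamo tracer for now.
pattern = re.compile("^test_([a])+")

for name in ["test_acos_aten.py", "test_atan2_aten.py", "test_binary_ops.py"]:
    print(name, "-> dynamo tracer" if pattern.match(name) else "-> fx.symbolic_trace")
# test_acos_aten.py -> dynamo tracer
# test_atan2_aten.py -> dynamo tracer
# test_binary_ops.py -> fx.symbolic_trace
```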
@@ -277,14 +351,26 @@ def generate_graph(
enable_passes: bool,
propagate_shapes: bool = False,
settings: CompilationSettings = CompilationSettings(),
torch_export_dynamic_shapes: Optional[Any] = None,
):
mod = mod.eval()
if use_dynamo_tracer:
exported_program = torch_tensorrt.dynamo.trace(mod, tuple(original_inputs))
exported_program = pre_export_lowering(exported_program, settings)
exported_program = exported_program.run_decompositions(
get_decompositions(False)
if torch_export_dynamic_shapes is None:
torch_export_dynamic_shapes = get_dynamic_shapes_args(
mod, original_inputs
)
device = default_device()
torch_export_inputs = get_torch_inputs(original_inputs, device)
exported_program = torch.export.export(
mod,
tuple(torch_export_inputs),
dynamic_shapes=torch_export_dynamic_shapes,
)
if enable_passes:
exported_program = pre_export_lowering(exported_program, settings)
exported_program = exported_program.run_decompositions(
get_decompositions(False)
)
fx_module = exported_program.module()
else:
fx_module = torch.fx.symbolic_trace(mod)
@@ -313,13 +399,15 @@ def run_test(
atol=ATOL,
precision=dtype.f32,
check_dtype=True,
use_dynamo_tracer=False,
use_dynamo_tracer=None,
enable_passes=False,
propagate_shapes=False,
int32_reqd=False,
make_refittable=False,
):

# TODO: lan to remove this and set use_dynamo_tracer to True by default
# once all the converter test files are moved to use_dynamo_tracer
use_dynamo_tracer = get_use_dynamo_tracer(use_dynamo_tracer)
# Previous instance of the interpreter auto-casted 64-bit inputs
# We replicate this behavior here
compilation_settings = CompilationSettings(
@@ -366,12 +454,18 @@

output_dtypes = None
if check_dtype:
output_dtypes = infer_module_output_dtypes_for_test(
mod,
input_specs,
compilation_settings.device,
truncate_double=compilation_settings.truncate_double,
)
if use_dynamo_tracer:
output_dtypes = infer_module_output_dtypes(
mod,
truncate_double=compilation_settings.truncate_double,
)
else:
output_dtypes = infer_module_output_dtypes_for_test(
mod,
input_specs,
compilation_settings.device,
truncate_double=compilation_settings.truncate_double,
)

_LOGGER.debug(f"Compilation settings: {compilation_settings}")
_LOGGER.debug(f"Inputs: {input_specs}")
@@ -441,37 +535,47 @@ def run_test_with_dynamic_shape(
rtol=RTOL,
atol=ATOL,
output_dtypes=None,
use_dynamo_tracer=False,
use_dynamo_tracer=None,
enable_passes=False,
use_example_tensors=True,
pyt_inputs=None,
propagate_shapes=False,
check_dtype=True,
make_refittable=False,
torch_export_dynamic_shapes=None,
):
# TODO: lan to remove this and set use_dynamo_tracer to True by default
# once all the converter test files are moved to use_dynamo_tracer
use_dynamo_tracer = get_use_dynamo_tracer(use_dynamo_tracer)

# Previous instance of the interpreter auto-casted 64-bit inputs
# We replicate this behavior here
compilation_settings = CompilationSettings(
truncate_double=True, make_refittable=make_refittable
)

mod = self.generate_graph(
mod,
input_specs,
use_dynamo_tracer=use_dynamo_tracer,
enable_passes=enable_passes,
propagate_shapes=propagate_shapes,
settings=compilation_settings,
torch_export_dynamic_shapes=torch_export_dynamic_shapes,
)

if check_dtype:
output_dtypes = infer_module_output_dtypes_for_test(
mod,
input_specs,
compilation_settings.device,
truncate_double=compilation_settings.truncate_double,
)
if use_dynamo_tracer:
output_dtypes = infer_module_output_dtypes(
mod,
truncate_double=compilation_settings.truncate_double,
)
else:
output_dtypes = infer_module_output_dtypes_for_test(
mod,
input_specs,
compilation_settings.device,
truncate_double=compilation_settings.truncate_double,
)

interp = TRTInterpreter(
mod,
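Not part of the diff: a minimal standalone sketch of the export-based tracing flow that generate_graph now follows when use_dynamo_tracer is True. The toy module and hand-written dynamic-shapes spec below are made up; in the harness the spec comes from get_dynamic_shapes_args(mod, original_inputs) and decompositions from get_decompositions(False).

```python
import torch


class ToyModel(torch.nn.Module):
    def forward(self, x):
        return torch.acos(x)


# The example tensor plays the role of the "opt" shape in the test input specs.
example = torch.randn(2, 2, 3)

# Declare dim 0 as dynamic; in the harness this mapping is produced by
# get_dynamic_shapes_args.
dim0 = torch.export.Dim("dim0", min=2, max=8)

exported = torch.export.export(
    ToyModel().eval(),
    (example,),
    dynamic_shapes=({0: dim0},),
)

# run_decompositions() lowers to core aten ops; .module() yields an
# fx.GraphModule whose nodes carry meta['val'], which is what lets run_test
# call infer_module_output_dtypes instead of the dummy-inference helper.
fx_module = exported.run_decompositions().module()
print(fx_module.graph)
```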
2 changes: 1 addition & 1 deletion tests/py/dynamo/conversion/test_acos_aten.py
@@ -66,7 +66,7 @@ def forward(self, input):
(
"3d_dim_dtype_float",
(1, 1, 1),
(1, 2, 3),
(2, 2, 3),
(3, 3, 3),
torch.float,
torch.float,
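Not part of the diff: the opt-shape bumps from 1 to 2 in this and several of the following test files are consistent with torch.export treating example sizes of 0/1 as specialized, so a dimension marked dynamic needs an example ("opt") size of at least 2. A small sketch under that assumption, with a made-up module:

```python
import torch


class Scale(torch.nn.Module):
    def forward(self, x):
        return x * 2


# torch.export treats example sizes of 0/1 as specialized, so the example
# ("opt") tensor for a dimension marked dynamic should have size >= 2; hence
# the opt-shape bumps from 1 to 2 in these converter tests.
dim = torch.export.Dim("d", min=1, max=4)
ep = torch.export.export(Scale(), (torch.randn(2, 3),), dynamic_shapes=({0: dim},))
print([n.meta.get("val") for n in ep.graph.nodes if n.op == "placeholder"])
```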
4 changes: 2 additions & 2 deletions tests/py/dynamo/conversion/test_acosh_aten.py
@@ -58,15 +58,15 @@ def forward(self, input):
(
"3d_dim_dtype_float",
(1, 1, 1),
(1, 2, 3),
(2, 2, 3),
(3, 3, 3),
torch.float,
torch.float,
),
(
"3d_dim_dtype_int32",
(1, 1, 1),
(1, 2, 4),
(2, 2, 4),
(2, 3, 5),
torch.int32,
torch.float,
8 changes: 4 additions & 4 deletions tests/py/dynamo/conversion/test_any.py
@@ -191,7 +191,7 @@ class TestAnyConverterDynamic(DispatchTestCase):
(
"3d_dynamic_float",
(2, 1, 1),
(2, 2, 1),
(2, 2, 2),
(3, 2, 4),
torch.float,
),
@@ -234,7 +234,7 @@ def forward(self, x):
(
"3d_dynamic_dim_float",
(2, 1, 1),
(2, 2, 1),
(2, 2, 2),
(3, 2, 4),
torch.float,
2,
@@ -252,7 +252,7 @@ def forward(self, x):
(
"3d_dynamic_dim_bool",
(2, 1, 1),
(2, 2, 1),
(2, 2, 2),
(3, 2, 4),
torch.bool,
0,
@@ -285,7 +285,7 @@ def forward(self, x):
(
"3d_dynamic_dims_float",
(2, 1, 1),
(2, 2, 1),
(2, 2, 2),
(3, 2, 4),
torch.float,
[1, 2],
1 change: 1 addition & 0 deletions tests/py/dynamo/conversion/test_arange_aten.py
@@ -56,6 +56,7 @@ def forward(self, end_tensor):
use_example_tensors=False,
check_dtype=False,
pyt_inputs=[pyt_input],
use_dynamo_tracer=False,
)


2 changes: 1 addition & 1 deletion tests/py/dynamo/conversion/test_asin_aten.py
@@ -66,7 +66,7 @@ def forward(self, input):
(
"3d_dim_dtype_float",
(1, 1, 1),
(1, 2, 3),
(2, 2, 3),
(3, 3, 3),
torch.float,
torch.float,
4 changes: 2 additions & 2 deletions tests/py/dynamo/conversion/test_asinh_aten.py
@@ -58,15 +58,15 @@ def forward(self, input):
(
"3d_dim_dtype_float",
(1, 1, 1),
(1, 2, 3),
(2, 2, 3),
(3, 3, 3),
torch.float,
torch.float,
),
(
"3d_dim_dtype_int32",
(1, 1, 1),
(1, 2, 4),
(2, 2, 4),
(2, 3, 5),
torch.int32,
torch.float,
22 changes: 17 additions & 5 deletions tests/py/dynamo/conversion/test_atan2_aten.py
@@ -1,3 +1,5 @@
import unittest

import torch
import torch.nn as nn
from parameterized import parameterized
@@ -141,15 +143,15 @@ def forward(self, lhs_val, rhs_val):
(
"3d_dim_dtype_float",
(1, 1, 1),
(1, 2, 3),
(2, 2, 3),
(3, 3, 3),
torch.float,
torch.float,
),
(
"3d_dim_dtype_int32",
(1, 1, 1),
(1, 2, 4),
(2, 2, 4),
(2, 3, 5),
torch.int32,
torch.float,
@@ -182,10 +184,17 @@ def forward(self, lhs_val, rhs_val):
)


# With the dynamo tracer, torch.ops.aten.atan2.out is decomposed/partitioned into core aten ops that torch_tensorrt
# supports (run_on_acc) and unsupported ops (run_on_gpu). That works via the torch_tensorrt.dynamo.compile workflow,
# but it is not valid for our converter test framework, so skip it here.
@unittest.skip("skip torch.ops.aten.atan2.out converter test")
class TestAtan2OutConverter(DispatchTestCase):
@parameterized.expand(
[
((10,), (5,), torch.float),
# dynamo trace does not allow the out tensor to have a different shape from the result;
# it raises Unsupported(msg, case_name=case_name)
# torch._dynamo.exc.Unsupported: out variants with resizing on graph inputs
((5,), (5,), torch.float),
((10,), (10,), torch.float),
]
)
@@ -220,7 +229,7 @@ def forward(self, lhs_val, rhs_val, out):
(
"3d_dim_dtype_float",
(1, 1, 1),
(1, 2, 3),
(2, 2, 3),
(3, 3, 3),
torch.float,
torch.float,
@@ -255,7 +264,10 @@ def forward(self, lhs_val, rhs_val, out):
),
]
self.run_test_with_dynamic_shape(
atan2(), input_specs, output_dtypes=[output_type]
atan2(),
input_specs,
output_dtypes=[output_type],
use_dynamo_tracer=False,
)


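Not part of the diff: a hedged sketch of the out-variant behavior referenced in the skip comment above. The module and shapes are made up, and the exact exception type depends on the torch version.

```python
import torch


class Atan2Out(torch.nn.Module):
    def forward(self, lhs, rhs, out):
        return torch.atan2(lhs, rhs, out=out)


lhs, rhs = torch.randn(5), torch.randn(5)

# An out tensor of the same shape is expected to trace; one that must be
# resized on a graph input is rejected by dynamo ("out variants with resizing
# on graph inputs"), which is why the (10,), (5,) parameterization was
# replaced with same-shape inputs above.
for out_shape in [(5,), (10,)]:
    try:
        torch.export.export(Atan2Out(), (lhs, rhs, torch.empty(out_shape)))
        print(out_shape, "-> traced")
    except Exception as e:
        print(out_shape, "->", type(e).__name__)
```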

