switch from fx.symbolic_trace to dynamo_trace for converter test part-1 #3261

Merged Nov 5, 2024 · 43 commits · changes shown from 42 commits

Commits
458a4d1 skip run_shape_analysis (lanluo-nvidia, Oct 6, 2024)
2f408f9 test (lanluo-nvidia, Oct 6, 2024)
1c5e86c test (lanluo-nvidia, Oct 6, 2024)
ba487dc test (lanluo-nvidia, Oct 6, 2024)
99d2274 Merge branch 'main' into lluo/save_remove_inputs (lanluo-nvidia, Oct 6, 2024)
2b43480 test (lanluo-nvidia, Oct 6, 2024)
b4e02e1 Merge branch 'main' into lluo/save_remove_inputs (lanluo-nvidia, Oct 11, 2024)
3d94f8b test (lanluo-nvidia, Oct 13, 2024)
28ba6cc Merge branch 'main' into lluo/save_remove_inputs (lanluo-nvidia, Oct 15, 2024)
b89cbe0 resolve comments (lanluo-nvidia, Oct 15, 2024)
2843d37 Merge branch 'main' into lluo/save_remove_inputs (lanluo-nvidia, Oct 16, 2024)
3eb48d7 test (lanluo-nvidia, Oct 16, 2024)
50eb0d8 replace dummy inference (lanluo-nvidia, Oct 20, 2024)
95ed602 test (lanluo-nvidia, Oct 20, 2024)
120f30d test (lanluo-nvidia, Oct 21, 2024)
424cbf7 add run_test_with_dynamic_shape change (lanluo-nvidia, Oct 21, 2024)
2fc9cef Merge branch 'main' into lluo/save_remove_inputs (lanluo-nvidia, Oct 21, 2024)
ef54cfc split the PR, add dummy inference for converter test (lanluo-nvidia, Oct 21, 2024)
14f5d61 test (lanluo-nvidia, Oct 22, 2024)
7563959 test (lanluo-nvidia, Oct 22, 2024)
77355f0 test (lanluo-nvidia, Oct 22, 2024)
891e963 enable converter non dynamic shape tests to use dynamo tracer (lanluo-nvidia, Oct 22, 2024)
13361fd add linear lowering meta val (lanluo-nvidia, Oct 22, 2024)
f0a9fef add linear_lowering change (lanluo-nvidia, Oct 23, 2024)
cff64a4 test (lanluo-nvidia, Oct 23, 2024)
814262f Merge branch 'lluo/save_remove_inputs' into lluo/switch_to_dynamo_trace (lanluo-nvidia, Oct 23, 2024)
933abac test (lanluo-nvidia, Oct 23, 2024)
8417684 resolve comments (lanluo-nvidia, Oct 25, 2024)
8676f88 test (lanluo-nvidia, Oct 25, 2024)
d8e52bf test (lanluo-nvidia, Oct 27, 2024)
4d46235 Merge branch 'lluo/save_remove_inputs' into lluo/switch_to_dynamo_trace (lanluo-nvidia, Oct 27, 2024)
8b3842a test (lanluo-nvidia, Oct 27, 2024)
7ddf56f test (lanluo-nvidia, Oct 27, 2024)
39e0a49 test (lanluo-nvidia, Oct 27, 2024)
076f47a resolve comments (lanluo-nvidia, Oct 29, 2024)
8250179 Merge branch 'main' into lluo/save_remove_inputs (lanluo-nvidia, Oct 29, 2024)
96e93e4 resolve comments (lanluo-nvidia, Oct 29, 2024)
7a9659f Merge branch 'lluo/save_remove_inputs' into lluo/switch_to_dynamo_trace (lanluo-nvidia, Oct 29, 2024)
cb656bb Merge branch 'main' into lluo/switch_to_dynamo_trace (lanluo-nvidia, Oct 29, 2024)
c023714 resolve comments (lanluo-nvidia, Oct 30, 2024)
594ca28 resolve comments (lanluo-nvidia, Oct 30, 2024)
56d034b resolve comments (lanluo-nvidia, Oct 31, 2024)
8be5d5f resolve comments (lanluo-nvidia, Nov 5, 2024)
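The PR's core change, condensed: the converter test harness stops tracing test modules with fx.symbolic_trace and starts tracing them with torch.export (the dynamo tracer), whose graphs carry dtype metadata. A minimal sketch of the difference, using a toy module that is not part of this PR:

```python
import torch

class ToyModule(torch.nn.Module):
    def forward(self, x):
        return torch.acos(x)

mod = ToyModule().eval()
example = torch.randn(2, 2, 3)

# Old path: fx.symbolic_trace produces a GraphModule whose nodes generally
# lack meta["val"], so output dtypes must be recovered by dummy inference.
fx_mod = torch.fx.symbolic_trace(mod)

# New path: torch.export (dynamo tracing) records fake-tensor metadata, so
# each node carries meta["val"] and dtypes can be read off the graph directly.
exported = torch.export.export(mod, (example,))
dynamo_mod = exported.module()
```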
144 changes: 124 additions & 20 deletions tests/py/dynamo/conversion/harness.py
@@ -14,10 +14,13 @@
from torch_tensorrt._Device import Device
from torch_tensorrt._enums import dtype
from torch_tensorrt.dynamo import _defaults
from torch_tensorrt.dynamo._defaults import default_device
from torch_tensorrt.dynamo._settings import CompilationSettings
from torch_tensorrt.dynamo._tracer import get_dynamic_shapes_args

# Use interpreter, input spec, and test case from fx_ts_compat to test Dynamo Converter Registry
from torch_tensorrt.dynamo.conversion import TRTInterpreter
from torch_tensorrt.dynamo.conversion._conversion import infer_module_output_dtypes
from torch_tensorrt.dynamo.lowering import (
get_decompositions,
post_lowering,
@@ -29,6 +32,77 @@
_LOGGER: logging.Logger = logging.getLogger(__name__)


# This method is only used in our converter tests to infer the module output dtypes via dummy inference,
# because fx.symbolic_trace does not populate the meta['val'] info on graph nodes.
# TODO: lan to remove this once our converter tests are moved from fx.symbolic_trace to dynamo trace
def infer_module_output_dtypes_for_test(
module: torch.fx.GraphModule,
inputs: Sequence[Input],
device: Device,
kwarg_inputs: Optional[dict[str, Any]] = None,
truncate_double: bool = False,
) -> List[dtype]:
"""
This function performs model inference to determine the output dtypes
and truncates them accordingly. inputs can be either arg_inputs or a flattened input list.
If it is a flattened list, kwarg_inputs should be None, since the kwargs are already included in the flattened inputs.
"""
# TODO: We can also determine output dtypes from the module.graph based on node metadata.
# However, our converter tests use fx.symbolic_trace which sometimes does not provide metadata,
# so we stick to the model inference approach currently.
with unset_fake_temporarily():
# Get the device on which the model exists
# For large models, this can be done on CPU to save GPU memory allocation for TRT.
device = get_model_device(module)
torch_inputs = get_torch_inputs(inputs, device)
if kwarg_inputs is None:
kwarg_inputs = {}
torch_kwarg_inputs = get_torch_inputs(kwarg_inputs, device)
module_outputs = module(*torch_inputs, **torch_kwarg_inputs)
if not isinstance(module_outputs, (list, tuple)):
module_outputs = [module_outputs]

# Int64 outputs can sometimes be generated from within other operators
# such as aten.sum - such outputs can be truncated
output_dtypes = []
for output in module_outputs:
output_ = output
# We don't need to check if output is nested here because the input module will be flattened
if not isinstance(output, torch.Tensor):
if isinstance(output, str):
raise ValueError(
f"Received an output type {type(output)} that's not in the acceptable datatypes (https://pytorch.org/docs/stable/tensor_attributes.html#torch.dtype)"
)
else:
output_ = torch.tensor(output)

if truncate_double and output_.dtype == dtype.float64:
output_dtypes.append(dtype.float32)
else:
output_dtypes.append(dtype._from(output_.dtype))

return output_dtypes


# This enables the dynamo tracer in the converter test files batch by batch.
def get_use_dynamo_tracer(use_dynamo_tracer: Any) -> bool:
    # If the converter test explicitly sets the use_dynamo_tracer field, honor it.
if use_dynamo_tracer is not None and isinstance(use_dynamo_tracer, bool):
return use_dynamo_tracer
    # Otherwise, decide based on the name of the calling test file.
import inspect
import os
import re

filename = os.path.basename(inspect.stack()[2].filename)
    # Enable converter test files whose names match test_a*.py to use the dynamo tracer.
pattern = re.compile("^test_([a])+")
if pattern.match(filename):
return True
else:
return False


@@ -277,14 +351,26 @@ def generate_graph(
enable_passes: bool,
propagate_shapes: bool = False,
settings: CompilationSettings = CompilationSettings(),
torch_export_dynamic_shapes: Optional[Any] = None,
):
mod = mod.eval()
if use_dynamo_tracer:
exported_program = torch_tensorrt.dynamo.trace(mod, tuple(original_inputs))
exported_program = pre_export_lowering(exported_program, settings)
exported_program = exported_program.run_decompositions(
get_decompositions(False)
if torch_export_dynamic_shapes is None:
torch_export_dynamic_shapes = get_dynamic_shapes_args(
mod, original_inputs
)
device = default_device()
torch_export_inputs = get_torch_inputs(original_inputs, device)
exported_program = torch.export.export(
mod,
tuple(torch_export_inputs),
dynamic_shapes=torch_export_dynamic_shapes,
)
if enable_passes:
exported_program = pre_export_lowering(exported_program, settings)
exported_program = exported_program.run_decompositions(
get_decompositions(False)
)
fx_module = exported_program.module()
else:
fx_module = torch.fx.symbolic_trace(mod)
@@ -313,13 +399,15 @@ def run_test(
atol=ATOL,
precision=dtype.f32,
check_dtype=True,
use_dynamo_tracer=False,
use_dynamo_tracer=None,
enable_passes=False,
propagate_shapes=False,
int32_reqd=False,
make_refittable=False,
):

# TODO: lan to remove this and set use_dynamo_tracer to True by default
# once all the converter test files are moved to use_dynamo_tracer
use_dynamo_tracer = get_use_dynamo_tracer(use_dynamo_tracer)
# Previous instance of the interpreter auto-casted 64-bit inputs
# We replicate this behavior here
compilation_settings = CompilationSettings(
@@ -366,12 +454,18 @@

output_dtypes = None
if check_dtype:
output_dtypes = infer_module_output_dtypes_for_test(
mod,
input_specs,
compilation_settings.device,
truncate_double=compilation_settings.truncate_double,
)
if use_dynamo_tracer:
output_dtypes = infer_module_output_dtypes(
mod,
truncate_double=compilation_settings.truncate_double,
)
else:
output_dtypes = infer_module_output_dtypes_for_test(
mod,
input_specs,
compilation_settings.device,
truncate_double=compilation_settings.truncate_double,
)

_LOGGER.debug(f"Compilation settings: {compilation_settings}")
_LOGGER.debug(f"Inputs: {input_specs}")
@@ -441,37 +535,47 @@ def run_test_with_dynamic_shape(
rtol=RTOL,
atol=ATOL,
output_dtypes=None,
use_dynamo_tracer=False,
use_dynamo_tracer=None,
enable_passes=False,
use_example_tensors=True,
pyt_inputs=None,
propagate_shapes=False,
check_dtype=True,
make_refittable=False,
torch_export_dynamic_shapes=None,
):
# TODO: lan to remove this and set use_dynamo_tracer to True by default
# once all the converter test files are moved to use_dynamo_tracer
use_dynamo_tracer = get_use_dynamo_tracer(use_dynamo_tracer)

# Previous instance of the interpreter auto-casted 64-bit inputs
# We replicate this behavior here
compilation_settings = CompilationSettings(
truncate_double=True, make_refittable=make_refittable
)

mod = self.generate_graph(
mod,
input_specs,
use_dynamo_tracer=use_dynamo_tracer,
enable_passes=enable_passes,
propagate_shapes=propagate_shapes,
settings=compilation_settings,
torch_export_dynamic_shapes=torch_export_dynamic_shapes,
)

if check_dtype:
output_dtypes = infer_module_output_dtypes_for_test(
mod,
input_specs,
compilation_settings.device,
truncate_double=compilation_settings.truncate_double,
)
if use_dynamo_tracer:
output_dtypes = infer_module_output_dtypes(
mod,
truncate_double=compilation_settings.truncate_double,
)
else:
output_dtypes = infer_module_output_dtypes_for_test(
mod,
input_specs,
compilation_settings.device,
truncate_double=compilation_settings.truncate_double,
)

interp = TRTInterpreter(
mod,
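For reference, a hypothetical converter test showing how the gating added to harness.py above behaves; DispatchTestCase and run_test are the real entry points, but the module, import path, and file layout here are illustrative assumptions:

```python
import torch
from harness import DispatchTestCase  # assumption: imported as in the conversion test directory

class TestAcosConverter(DispatchTestCase):
    def test_acos(self):
        class Acos(torch.nn.Module):
            def forward(self, x):
                return torch.ops.aten.acos.default(x)

        # use_dynamo_tracer is left as None: if this file's name matches the
        # test_a*.py pattern in get_use_dynamo_tracer(), the dynamo tracer is
        # enabled automatically and output dtypes come from node metadata via
        # infer_module_output_dtypes instead of dummy inference. Passing
        # use_dynamo_tracer=False (as test_arange_aten.py does below) opts an
        # individual test out.
        self.run_test(Acos(), [torch.randn(2, 2, 3)])
```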
2 changes: 1 addition & 1 deletion tests/py/dynamo/conversion/test_acos_aten.py
@@ -66,7 +66,7 @@ def forward(self, input):
(
"3d_dim_dtype_float",
(1, 1, 1),
(1, 2, 3),
(2, 2, 3),
(3, 3, 3),
torch.float,
torch.float,
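The recurring shape edit in this and the following test files, where the optimal dynamic shape moves off 1 (e.g. (1, 2, 3) becomes (2, 2, 3)), is consistent with torch.export's 0/1 specialization: a dimension whose example value is 0 or 1 is burned in as a constant, which conflicts with marking it dynamic. A hedged sketch of that behavior, not taken from the PR:

```python
import torch
from torch.export import Dim, export

class Acos(torch.nn.Module):
    def forward(self, x):
        return torch.acos(x)

# Mark all three dimensions dynamic over the test's min/max range [1, 3].
dyn = {"x": (Dim("d0", min=1, max=3), Dim("d1", min=1, max=3), Dim("d2", min=1, max=3))}

# An example input of (1, 2, 3) would specialize d0 to the constant 1 and
# break the export; (2, 2, 3) keeps every dynamic dimension away from 0/1.
ep = export(Acos(), (torch.randn(2, 2, 3),), dynamic_shapes=dyn)
```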
4 changes: 2 additions & 2 deletions tests/py/dynamo/conversion/test_acosh_aten.py
@@ -58,15 +58,15 @@ def forward(self, input):
(
"3d_dim_dtype_float",
(1, 1, 1),
(1, 2, 3),
(2, 2, 3),
(3, 3, 3),
torch.float,
torch.float,
),
(
"3d_dim_dtype_int32",
(1, 1, 1),
(1, 2, 4),
(2, 2, 4),
(2, 3, 5),
torch.int32,
torch.float,
8 changes: 4 additions & 4 deletions tests/py/dynamo/conversion/test_any.py
@@ -191,7 +191,7 @@ class TestAnyConverterDynamic(DispatchTestCase):
(
"3d_dynamic_float",
(2, 1, 1),
(2, 2, 1),
(2, 2, 2),
(3, 2, 4),
torch.float,
),
@@ -234,7 +234,7 @@ def forward(self, x):
(
"3d_dynamic_dim_float",
(2, 1, 1),
(2, 2, 1),
(2, 2, 2),
(3, 2, 4),
torch.float,
2,
@@ -252,7 +252,7 @@ def forward(self, x):
(
"3d_dynamic_dim_bool",
(2, 1, 1),
(2, 2, 1),
(2, 2, 2),
(3, 2, 4),
torch.bool,
0,
@@ -285,7 +285,7 @@ def forward(self, x):
(
"3d_dynamic_dims_float",
(2, 1, 1),
(2, 2, 1),
(2, 2, 2),
(3, 2, 4),
torch.float,
[1, 2],
1 change: 1 addition & 0 deletions tests/py/dynamo/conversion/test_arange_aten.py
@@ -56,6 +56,7 @@ def forward(self, end_tensor):
use_example_tensors=False,
check_dtype=False,
pyt_inputs=[pyt_input],
use_dynamo_tracer=False,
)


2 changes: 1 addition & 1 deletion tests/py/dynamo/conversion/test_asin_aten.py
@@ -66,7 +66,7 @@ def forward(self, input):
(
"3d_dim_dtype_float",
(1, 1, 1),
(1, 2, 3),
(2, 2, 3),
(3, 3, 3),
torch.float,
torch.float,
4 changes: 2 additions & 2 deletions tests/py/dynamo/conversion/test_asinh_aten.py
@@ -58,15 +58,15 @@ def forward(self, input):
(
"3d_dim_dtype_float",
(1, 1, 1),
(1, 2, 3),
(2, 2, 3),
(3, 3, 3),
torch.float,
torch.float,
),
(
"3d_dim_dtype_int32",
(1, 1, 1),
(1, 2, 4),
(2, 2, 4),
(2, 3, 5),
torch.int32,
torch.float,
22 changes: 17 additions & 5 deletions tests/py/dynamo/conversion/test_atan2_aten.py
@@ -1,3 +1,5 @@
import unittest

import torch
import torch.nn as nn
from parameterized import parameterized
@@ -141,15 +143,15 @@ def forward(self, lhs_val, rhs_val):
(
"3d_dim_dtype_float",
(1, 1, 1),
(1, 2, 3),
(2, 2, 3),
(3, 3, 3),
torch.float,
torch.float,
),
(
"3d_dim_dtype_int32",
(1, 1, 1),
(1, 2, 4),
(2, 2, 4),
(2, 3, 5),
torch.int32,
torch.float,
@@ -182,10 +184,17 @@ def forward(self, lhs_val, rhs_val):
)


# Under the dynamo tracer, torch.ops.aten.atan2.out is decomposed/partitioned into core aten ops that
# torch_tensorrt supports (run_on_acc) and unsupported ops (run_on_gpu). That works via the
# torch_tensorrt.dynamo.compile workflow but is not valid for our converter test framework, so skip it here.
@unittest.skip("skip torch.ops.aten.atan2.out converter test")
class TestAtan2OutConverter(DispatchTestCase):
@parameterized.expand(
[
((10,), (5,), torch.float),
            # Dynamo tracing does not allow the out tensor to have a different shape;
            # it raises torch._dynamo.exc.Unsupported: out variants with resizing on graph inputs
((5,), (5,), torch.float),
((10,), (10,), torch.float),
]
)
@@ -220,7 +229,7 @@ def forward(self, lhs_val, rhs_val, out):
(
"3d_dim_dtype_float",
(1, 1, 1),
(1, 2, 3),
(2, 2, 3),
(3, 3, 3),
torch.float,
torch.float,
@@ -255,7 +264,10 @@ def forward(self, lhs_val, rhs_val, out):
),
]
self.run_test_with_dynamic_shape(
atan2(), input_specs, output_dtypes=[output_type]
atan2(),
input_specs,
output_dtypes=[output_type],
use_dynamo_tracer=False,
)


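The out-variant restriction called out in the comments above can be reproduced directly; a small sketch of the assumed behavior, matching the error message quoted in the diff:

```python
import torch

class Atan2Out(torch.nn.Module):
    def forward(self, lhs, rhs, out):
        return torch.ops.aten.atan2.out(lhs, rhs, out=out)

mod = Atan2Out()

# When the out tensor already has the result shape, no resizing is needed
# and dynamo tracing can proceed; hence the test inputs were changed to
# matching shapes like ((5,), (5,)) above.
ok = torch.export.export(mod, (torch.randn(5), torch.randn(5), torch.empty(5)))

# A mismatched out tensor would have to be resized during tracing, which
# raises torch._dynamo.exc.Unsupported: out variants with resizing on graph inputs
# torch.export.export(mod, (torch.randn(10), torch.randn(10), torch.empty(5)))
```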