Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support flatten and reshape via shuffle_layer #2354

Merged
merged 11 commits into from
Oct 6, 2023
24 changes: 24 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -1555,3 +1555,27 @@ def tensorrt_scaled_dot_product_attention(
return impl.attention.scaled_dot_product_attention(
ctx, target, SourceIR.ATEN, name, args[0], args[1], args[2]
)


@dynamo_tensorrt_converter(torch.ops.aten.reshape.default) # type: ignore[misc]
@dynamo_tensorrt_converter(torch.ops.aten.view.default) # type: ignore[misc]
@enforce_tensor_types(
{
0: (TRTTensor,),
}
) # type: ignore[misc]
def aten_ops_reshape(
ctx: ConversionContext,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
return impl.shuffle.reshape(
ctx,
target,
SourceIR.ATEN,
name,
input=args[0],
shape=args[1],
)
32 changes: 32 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/converter_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -511,3 +511,35 @@ def to_numpy(
raise AssertionError(
f"to_numpy can only be called on None, bool, int, float, np.ndarray, or torch.Tensor, got: {value}"
)


def flatten_dims(
zewenli98 marked this conversation as resolved.
Show resolved Hide resolved
input: Sequence[Union[TRTTensor, torch.Tensor, np.ndarray]],
start_dim: int,
end_dim: int,
) -> Tuple[int, ...]:
"""
Given an input, start and end indices of dimension,
this function will return a flattened new shape.

Args:
input (Sequence[Union[TRTTensor, torch.Tensor, np.ndarray]]):
an input value waiting to be flattened
start_dim (int): the first dim to flatten
end_dim (int): the last dim to flatten (this dim is included)

Returns:
Tuple[int]: new_shape
"""
shape = input.shape
dim_size = len(shape)
start_dim = get_positive_dim(start_dim, dim_size)
end_dim = get_positive_dim(end_dim, dim_size)

num_elements = 1
for i in range(start_dim, end_dim + 1):
num_elements *= shape[i]

new_shape = tuple(shape[:start_dim]) + (num_elements,) + tuple(shape[end_dim + 1 :])

return new_shape
1 change: 1 addition & 0 deletions py/torch_tensorrt/dynamo/conversion/impl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
reduce,
select,
shape,
shuffle,
slice,
split,
squeeze,
Expand Down
21 changes: 21 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/impl/shuffle.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from typing import Optional, Sequence, Union

from torch.fx.node import Target
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
from torch_tensorrt.dynamo.conversion.converter_utils import SourceIR
from torch_tensorrt.fx.converters.converter_utils import set_layer_name
from torch_tensorrt.fx.types import TRTTensor


def reshape(
ctx: ConversionContext,
target: Union[Target, str],
source_ir: Optional[SourceIR],
name: str,
input: TRTTensor,
shape: Sequence[int],
) -> TRTTensor:
layer = ctx.net.add_shuffle(input)
layer.reshape_dims = tuple(shape)
set_layer_name(layer, target, name, source_ir)
return layer.get_output(0)
40 changes: 39 additions & 1 deletion tests/py/dynamo/conversion/test_converter_utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
import numpy as np
import torch
from parameterized import parameterized
from torch.testing._internal.common_utils import TestCase, run_tests
from torch_tensorrt.dynamo.conversion.converter_utils import enforce_tensor_types
from torch_tensorrt.dynamo.conversion.converter_utils import (
enforce_tensor_types,
flatten_dims,
)
from torch_tensorrt.fx.types import TRTTensor

from ..testing_utilities import DECIMALS_OF_AGREEMENT, lower_graph_testing
Expand Down Expand Up @@ -37,5 +41,39 @@ def test_invalid_invocation_type(self):
enforce_tensor_types({0: (int, bool)})


class TestFlattenDimsEnforcement(TestCase):
@parameterized.expand(
[
((1, 2), 0, 0, (1, 2)),
((1, 2), 0, 1, (2,)),
((2, 3, 4), 1, 2, (2, 12)),
((2, 3, 4), 0, 1, (6, 4)),
((2, 3, 4), -3, 2, (24,)),
((2, 3, 4, 5), 0, -2, (24, 5)),
((2, 3, 4, 5), -4, -1, (120,)),
]
)
def test_numpy_array(self, input_shape, start_dim, end_dim, true_shape):
inputs = np.random.randn(*input_shape)
new_shape = flatten_dims(inputs, start_dim, end_dim)
self.assertEqual(new_shape, true_shape)

@parameterized.expand(
[
((1, 2), 0, 0, (1, 2)),
((1, 2), 0, 1, (2,)),
((2, 3, 4), 1, 2, (2, 12)),
((2, 3, 4), 0, 1, (6, 4)),
((2, 3, 4), -3, 2, (24,)),
((2, 3, 4, 5), 0, -2, (24, 5)),
((2, 3, 4, 5), -4, -1, (120,)),
]
)
def test_torch_tensor(self, input_shape, start_dim, end_dim, true_shape):
inputs = torch.randn(input_shape)
new_shape = flatten_dims(inputs, start_dim, end_dim)
self.assertEqual(new_shape, true_shape)


if __name__ == "__main__":
run_tests()
4 changes: 4 additions & 0 deletions tests/py/dynamo/conversion/test_reshape_aten.py
zewenli98 marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
class TestReshapeConverter(DispatchTestCase):
@parameterized.expand(
[
((-1,),),
((20,),),
((1, 20),),
((1, 10, -1),),
]
Expand All @@ -37,6 +39,7 @@ def forward(self, x):

@parameterized.expand(
[
((-1,),),
((-1, 10),),
((-1, 5),),
((2, 2, -1),),
Expand Down Expand Up @@ -65,6 +68,7 @@ def forward(self, x):
self.run_test_with_dynamic_shape(
TestModule(target_shape),
input_specs,
expected_ops={torch.ops.aten.view.default},
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This line can be removed, in accordance with the new testing PR

)


Expand Down
Loading