Skip to content

Commit

Permalink
feat: support aten.flip dynamo converter (#2540)
Browse files Browse the repository at this point in the history
  • Loading branch information
zewenli98 authored Jan 15, 2024
1 parent b8403b8 commit fd19353
Show file tree
Hide file tree
Showing 3 changed files with 94 additions and 0 deletions.
23 changes: 23 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -2626,3 +2626,26 @@ def aten_ops_pdist(
args[0],
args_bounds_check(args, 1, 2),
)


@dynamo_tensorrt_converter(torch.ops.aten.flip.default)
@enforce_tensor_types(
{
0: (TRTTensor,),
}
)
def aten_ops_flip(
ctx: ConversionContext,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
return impl.slice.flip(
ctx,
target,
SourceIR.ATEN,
name,
args[0],
args[1],
)
34 changes: 34 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,3 +225,37 @@ def tile(
layer.mode = trt.SampleMode.WRAP
set_layer_name(layer, target, name)
return layer.get_output(0)


def flip(
ctx: ConversionContext,
target: Target,
source_ir: Optional[SourceIR],
name: str,
input: TRTTensor,
dims: Sequence[int],
) -> TRTTensor:
start_slice = []
output_shape = list(input.shape)
stride_slice = []

shape = input.shape
rank = len(shape)
dims = get_positive_dim(dims, rank)

for i in range(rank):
if i in dims:
start_slice.append(shape[i] - 1)
stride_slice.append(-1)
else:
start_slice.append(0)
stride_slice.append(1)

layer = ctx.net.add_slice(
input,
start=start_slice,
shape=output_shape,
stride=stride_slice,
)
set_layer_name(layer, target, name, source_ir)
return layer.get_output(0)
37 changes: 37 additions & 0 deletions tests/py/dynamo/conversion/test_flip_aten.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import torch
import torch.nn as nn
from parameterized import parameterized
from torch.testing._internal.common_utils import run_tests

from .harness import DispatchTestCase


class TestFlipConverter(DispatchTestCase):
@parameterized.expand(
[
((3,), [0]),
((3,), [-1]),
((3,), []),
((3, 3), [0, 1]),
((3, 3), [-2, 1]),
((2, 3, 4), [0]),
((3, 3, 3), (0, 1)),
((2, 3, 4), [0, 1, 2]),
((2, 3, 4), [-3, -2, -1]),
((3, 3, 3, 3), [0]),
((2, 3, 4, 5), [0, 1, 2, 3]),
((2, 3, 4, 5), [-4, 1, -2, 3]),
((2, 3, 4, 5), []),
]
)
def test_flip(self, shape, dims):
class Flip(nn.Module):
def forward(self, x):
return torch.ops.aten.flip.default(x, dims)

inputs = [torch.randn(shape)]
self.run_test(Flip(), inputs)


if __name__ == "__main__":
run_tests()

0 comments on commit fd19353

Please sign in to comment.