Skip to content

Commit

Permalink
feat: support chunk dynamo converter (#2401)
Browse files Browse the repository at this point in the history
  • Loading branch information
zewenli98 authored Oct 25, 2023
1 parent cb20f90 commit 533215c
Show file tree
Hide file tree
Showing 3 changed files with 161 additions and 0 deletions.
24 changes: 24 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -667,6 +667,30 @@ def aten_ops_slice(
)


@dynamo_tensorrt_converter(torch.ops.aten.chunk.default)  # type: ignore[misc]
@enforce_tensor_types(
    {
        0: (TRTTensor,),
    }
)  # type: ignore[misc]
def aten_ops_chunk(
    ctx: ConversionContext,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    """Convert ``aten.chunk.default`` to TensorRT slice layers.

    Delegates to the slice implementation. The split dimension is the
    optional third positional argument and defaults to 0 when absent.
    """
    input_val = args[0]
    num_chunks = args[1]
    split_dim = args_bounds_check(args, 2, 0)
    return impl.slice.chunk(
        ctx, target, SourceIR.ATEN, name, input_val, num_chunks, split_dim
    )


@dynamo_tensorrt_converter(torch.ops.aten.permute.default) # type: ignore[misc]
@enforce_tensor_types(
{
Expand Down
55 changes: 55 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,3 +102,58 @@ def expand(
layer = ctx.net.add_slice(input_t, start=start, shape=shape, stride=stride)
set_layer_name(layer, target, name, source_ir)
return layer.get_output(0)


def chunk(
    ctx: ConversionContext,
    target: Target,
    source_ir: Optional[SourceIR],
    name: str,
    input: TRTTensor,
    chunks: int,
    dim: int,
) -> "List[TRTTensor]":
    """Split ``input`` into up to ``chunks`` pieces along dimension ``dim``.

    Mirrors ``torch.chunk``: every piece spans ``ceil(size_dim / chunks)``
    elements except possibly the last, which carries the remainder. Fewer
    than ``chunks`` pieces are produced when the dimension is too small.

    Args:
        ctx: Conversion context holding the TensorRT network being built.
        target: FX node target being converted.
        source_ir: Originating IR of the op (e.g. ATEN), used for layer naming.
        name: Base name for the slice layers created here.
        input: TensorRT tensor to split.
        chunks: Requested number of chunks; must be > 0.
        dim: Dimension to split along; may be negative.

    Returns:
        A list of TensorRT tensors, one per produced chunk.

    Raises:
        RuntimeError: If ``chunks`` <= 0 or ``dim`` is out of range.
    """
    if chunks <= 0:
        raise RuntimeError(
            f"chunk expects `chunks` to be greater than 0, got: {chunks}"
        )

    shape = input.shape
    dim = get_positive_dim(dim, len(shape))

    if dim >= len(shape):
        raise RuntimeError(
            f"chunk expects `dim` to be less than the length of input shape, got: {dim}"
        )

    # Slice bounds are computed at conversion time, so the size of the
    # split dimension must be static even if other dimensions are dynamic.
    if has_dynamic_shape(input.shape):
        # Check whether slice target dim is dynamic shape dim
        assert input.shape[dim] != -1, "Can't chunk on dynamic shape dimension!"

    size_dim = shape[dim]
    chunk_size = math.ceil(size_dim / chunks)
    result = []
    start = 0
    end = min(start + chunk_size, size_dim)
    cnt = 0

    # Emit one slice layer per chunk; the final slice may be shorter.
    # NOTE(review): when size_dim == 0 this returns an empty list, whereas
    # torch.chunk returns a single empty tensor — confirm callers never
    # hit a zero-sized split dimension.
    while start < end:
        result.append(
            slice_op(
                ctx,
                target,
                source_ir,
                f"{name}_slice_{cnt}",
                input,
                dim,
                start,
                end,
                1,
            )
        )
        start = end
        end = min(start + chunk_size, size_dim)
        cnt += 1

    return result
82 changes: 82 additions & 0 deletions tests/py/dynamo/conversion/test_chunk_aten.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
import torch
from parameterized import parameterized
from torch.testing._internal.common_utils import run_tests

from .harness import DispatchTestCase


class TestChunkConverter(DispatchTestCase):
    """Tests for the ``aten.chunk.default`` dynamo converter.

    Each parameterized case is ``(input_shape, chunks, dim)``. Dims may be
    negative, and ``chunks`` may exceed the dimension size, matching
    ``torch.chunk`` semantics.
    """

    @parameterized.expand(
        [
            ((1,), 3, 0),
            ((3,), 3, 0),
            ((4,), 3, 0),
            ((6,), 3, 0),
            ((3,), 1, -1),
            ((3,), 3, -1),
            ((3,), 4, -1),
        ]
    )
    def test_chunk_1D(self, shape, chunks, dim):
        class TestChunk(torch.nn.Module):
            def forward(self, x):
                out = torch.ops.aten.chunk.default(x, chunks, dim)
                return out

        # `inputs` (not `input`) avoids shadowing the builtin.
        inputs = [torch.randn(shape)]
        self.run_test(
            TestChunk(),
            inputs,
        )

    @parameterized.expand(
        [
            ((3, 4), 1, 0),
            ((3, 4), 3, 0),
            ((3, 4), 4, 0),
            ((3, 4), 2, -2),
            ((3, 4), 6, -2),
            ((3, 4), 3, 1),
            ((3, 4), 4, 1),
            ((3, 4), 5, -1),
        ]
    )
    def test_chunk_2D(self, shape, chunks, dim):
        class TestChunk(torch.nn.Module):
            def forward(self, x):
                out = torch.ops.aten.chunk.default(x, chunks, dim)
                return out

        inputs = [torch.randn(shape)]
        self.run_test(
            TestChunk(),
            inputs,
        )

    @parameterized.expand(
        [
            ((3, 4, 2), 1, 0),
            ((3, 4, 2), 3, -3),
            ((3, 4, 2), 3, 1),
            ((3, 4, 2), 4, 1),
            ((3, 4, 2), 6, -2),
            ((3, 4, 2), 1, 2),
            ((3, 4, 2), 3, -1),
            ((3, 4, 2), 4, -1),
        ]
    )
    def test_chunk_3D(self, shape, chunks, dim):
        class TestChunk(torch.nn.Module):
            def forward(self, x):
                out = torch.ops.aten.chunk.default(x, chunks, dim)
                return out

        inputs = [torch.randn(shape)]
        self.run_test(
            TestChunk(),
            inputs,
        )


if __name__ == "__main__":
run_tests()

0 comments on commit 533215c

Please sign in to comment.