diff --git a/py/torch_tensorrt/dynamo/conversion/converter_utils.py b/py/torch_tensorrt/dynamo/conversion/converter_utils.py
index 7d133ceffa..b65f95f0e5 100644
--- a/py/torch_tensorrt/dynamo/conversion/converter_utils.py
+++ b/py/torch_tensorrt/dynamo/conversion/converter_utils.py
@@ -181,8 +181,7 @@ def cast_int_int_div_trt_tensor(
 
 
 def broadcastable(
-    a: TRTTensor,
-    b: TRTTensor,
+    a: Union[TRTTensor, np.ndarray], b: Union[TRTTensor, np.ndarray]
 ) -> bool:
     "Check if two tensors are broadcastable according to torch rules"
     a_shape = tuple(a.shape)
diff --git a/py/torch_tensorrt/dynamo/conversion/impl/select.py b/py/torch_tensorrt/dynamo/conversion/impl/select.py
index 70f94cdca8..db586be65f 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/select.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/select.py
@@ -3,6 +3,7 @@
 
 import numpy as np
 import tensorrt as trt
+import torch
 from torch.fx.node import Target
 from torch_tensorrt.dynamo._SourceIR import SourceIR
 from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
@@ -80,14 +81,20 @@ def index(
     source_ir: Optional[SourceIR],
     name: str,
     input: TRTTensor,
-    index: Union[TRTTensor, Sequence[TRTTensor]],
+    index: Sequence[Union[TRTTensor, np.ndarray, torch.Tensor]],
 ) -> TRTTensor:
     adv_indx_indices = []
     tensor_indices = []
-    # _LOGGER.debug(f"The index shape is {index.shape}")
     # check if the input is dynamic
     dynamic_shape = has_dynamic_shape(input.shape)
-
+    # is_numpy is True only when every non-None index is a numpy array or a
+    # torch.Tensor; if any index is an ITensor, it is set to False
+    _LOGGER.debug(
+        "Determining whether aten.index constant-index optimization can be invoked"
+    )
+    is_numpy = all(
+        isinstance(ind, (torch.Tensor, np.ndarray)) for ind in index if ind is not None
+    )
     # here we need to check if all the index are broadcastable
     # if no, then we need to broadcast
     last_index = None
@@ -95,8 +102,13 @@ def index(
         if ind is not None:
             _LOGGER.debug(f"Shape of {i} index is {ind.shape}")
             adv_indx_indices.append(i)
-            # torch.nn.parameter.Parameter=> torch.Tensor
-            ind = get_trt_tensor(ctx, ind, name + f"_parameter_to_fp32_tensor_{i}")
+            # torch.nn.parameter.Parameter => numpy array
+            # numpy arrays are kept as numpy
+            # all other cases are converted to TRTTensor
+            if is_numpy:
+                ind = to_numpy(ind)
+            else:
+                ind = get_trt_tensor(ctx, ind, name + f"_parameter_to_fp32_tensor_{i}")
             if last_index is not None:
                 assert broadcastable(
                     ind, last_index
@@ -110,8 +122,9 @@ def index(
         set_layer_name(identity_layer, target, name + "_index_identity", source_ir)
         return identity_layer.get_output(0)
     elif len(tensor_indices) == 1:
-        # This case works
-        indices_tensor = tensor_indices[0]
+        indices_tensor = get_trt_tensor(
+            ctx, tensor_indices[0], name + "_parameter_to_fp32_tensor"
+        )
         index = adv_indx_indices[0]
         _LOGGER.debug(f"The advanced index indices is {adv_indx_indices}")
         gather_layer = ctx.net.add_gather(input, indices_tensor, index)
@@ -150,6 +163,7 @@ def index(
             if i not in adv_indx_indices:
                 new_order.append(i)
         _LOGGER.debug(f"The new transpose order is {new_order}")
+        transpose_layer.second_transpose = tuple(new_order)
         set_layer_name(transpose_layer, target, name + "_index_transpose", source_ir)
         transpose_tensor = transpose_layer.get_output(0)
 
@@ -175,47 +189,58 @@ def index(
             concat_tensor = concat_tensor_layer.get_output(0)
 
             reshape_layer = ctx.net.add_shuffle(transpose_tensor)
-            # check this
             reshape_layer.set_input(1, concat_tensor)
             flatten_tensor = reshape_layer.get_output(0)
+            _LOGGER.debug(f"The flatten tensor shape is {flatten_tensor.shape}")
 
             # tensor index = \sum_{i=1}^m (ind_i * \prod_{j=i+1}^m (x_j)), ind_i is input indices[i], x_j is the
             # // j dimension of input x.
-            multiplier = get_trt_tensor(
-                ctx,
-                dim_tensor_list[adv_indx_indices[adv_indx_count - 1]],
-                name + "_dim_last",
-            )
-            cum_adv_index = tensor_indices[adv_indx_count - 1]
-            for i in range(adv_indx_count - 2, -1, -1):
-                adv_index = convert_binary_elementwise(
-                    ctx,
-                    target,
-                    source_ir,
-                    name + f"_index_intermediate_{i}",
-                    trt.ElementWiseOperation.PROD,
-                    multiplier,
-                    tensor_indices[i],
+            if is_numpy:
+                multiplier = input_shape[adv_indx_indices[adv_indx_count - 1]]
+                cum_adv_index = tensor_indices[adv_indx_count - 1]
+                for i in range(adv_indx_count - 2, -1, -1):
+                    adv_index = multiplier * tensor_indices[i]
+                    cum_adv_index = cum_adv_index + adv_index
+                    multiplier = multiplier * input_shape[adv_indx_indices[i]]
+                cum_adv_index = get_trt_tensor(
+                    ctx, cum_adv_index, name + "_index_sum_intermediate"
                 )
-                cum_adv_index = convert_binary_elementwise(
-                    ctx,
-                    target,
-                    source_ir,
-                    name + f"_index_sum_intermediate_{i}",
-                    trt.ElementWiseOperation.SUM,
-                    cum_adv_index,
-                    adv_index,
-                )
-                multiplier = convert_binary_elementwise(
+            else:
+                multiplier = get_trt_tensor(
                     ctx,
-                    target,
-                    source_ir,
-                    name + f"_index_intermediate_xj_{i}",
-                    trt.ElementWiseOperation.PROD,
-                    multiplier,
-                    dim_tensor_list[adv_indx_indices[i]],
+                    dim_tensor_list[adv_indx_indices[adv_indx_count - 1]],
+                    name + "_dim_last",
                 )
+                cum_adv_index = tensor_indices[adv_indx_count - 1]
+                for i in range(adv_indx_count - 2, -1, -1):
+                    adv_index = convert_binary_elementwise(
+                        ctx,
+                        target,
+                        source_ir,
+                        name + f"_index_intermediate_{i}",
+                        trt.ElementWiseOperation.PROD,
+                        multiplier,
+                        tensor_indices[i],
+                    )
+                    cum_adv_index = convert_binary_elementwise(
+                        ctx,
+                        target,
+                        source_ir,
+                        name + f"_index_sum_intermediate_{i}",
+                        trt.ElementWiseOperation.SUM,
+                        cum_adv_index,
+                        adv_index,
+                    )
+                    multiplier = convert_binary_elementwise(
+                        ctx,
+                        target,
+                        source_ir,
+                        name + f"_index_intermediate_xj_{i}",
+                        trt.ElementWiseOperation.PROD,
+                        multiplier,
+                        dim_tensor_list[adv_indx_indices[i]],
+                    )
 
             gather_layer_element = ctx.net.add_gather(flatten_tensor, cum_adv_index, 0)
             set_layer_name(
diff --git a/tests/py/dynamo/conversion/test_index_aten.py b/tests/py/dynamo/conversion/test_index_aten.py
index 393eb53c63..df61a4b835 100644
--- a/tests/py/dynamo/conversion/test_index_aten.py
+++ b/tests/py/dynamo/conversion/test_index_aten.py
@@ -2,11 +2,10 @@
 
 import torch
 import torch.nn as nn
+from .harness import DispatchTestCase
 from torch.testing._internal.common_utils import run_tests
 from torch_tensorrt import Input
 
-from .harness import DispatchTestCase
-
 
 class TestIndexConverter(DispatchTestCase):
     def test_index_zero_two_dim(self):
@@ -27,6 +26,21 @@ def forward(self, x):
             input,
         )
 
+    def test_index_zero_two_dim_ITensor(self):
+        class TestModule(nn.Module):
+            def forward(self, x, index0):
+                indices = [None, index0]
+                out = torch.ops.aten.index.Tensor(x, indices)
+                return out
+
+        input = torch.randn(2, 2)
+        index0 = torch.randint(0, 1, (1, 1))
+        index0 = index0.to(torch.int32)
+        self.run_test(
+            TestModule(),
+            [input, index0],
+        )
+
     def test_index_zero_index_three_dim(self):
         class TestModule(nn.Module):
             def __init__(self):
@@ -44,6 +58,18 @@ def forward(self, x):
             input,
         )
 
+    def test_index_zero_index_three_dim_ITensor(self):
+        class TestModule(nn.Module):
+            def forward(self, x, index0):
+                indices = [None, index0, None]
+                out = torch.ops.aten.index.Tensor(x, indices)
+                return out
+
+        input = torch.randn(2, 2, 2)
+        index0 = torch.randint(0, 1, (1, 1))
+        index0 = index0.to(torch.int32)
+        self.run_test(TestModule(), [input, index0])
+
     def test_index_zero_index_one_index_two_three_dim(self):
         class TestModule(nn.Module):
             def __init__(self):
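
Note for reviewers: the sketch below is not part of the diff; it is a minimal standalone illustration of the arithmetic behind the new is_numpy fast path in impl/select.py. When every index is a constant (numpy array or torch.Tensor), the converter folds the advanced indices into a single gather index over the flattened input, cum_adv_index = \sum_{i=1}^m (ind_i * \prod_{j=i+1}^m x_j), on the host at conversion time instead of emitting a chain of elementwise PROD/SUM TRT layers. The shapes, dims, and index values here are made up for illustration, and the example assumes the advanced dims are already leading and contiguous (the converter arranges this with the transpose above).

    import numpy as np
    import torch

    # Toy input with two constant advanced indices on dims 0 and 1.
    x = torch.randn(3, 4, 5)
    adv_dims = [0, 1]
    indices = [np.array([0, 2, 1]), np.array([1, 3, 0])]
    input_shape = x.shape

    # Same accumulation as the is_numpy branch: fold the index of each
    # advanced dim into one index over the flattened (dim0 x dim1) axis.
    m = len(adv_dims)
    multiplier = input_shape[adv_dims[m - 1]]
    cum_adv_index = indices[m - 1]
    for i in range(m - 2, -1, -1):
        cum_adv_index = cum_adv_index + multiplier * indices[i]
        multiplier = multiplier * input_shape[adv_dims[i]]

    # A single gather on the flattened tensor reproduces aten.index.
    flat = x.reshape(input_shape[0] * input_shape[1], input_shape[2])
    gathered = flat[torch.from_numpy(cum_adv_index)]
    expected = torch.ops.aten.index.Tensor(
        x, [torch.from_numpy(ind) for ind in indices]
    )
    assert torch.equal(gathered, expected)

Because cum_adv_index is materialized as a plain numpy array, get_trt_tensor can embed it as a single constant feeding the shared add_gather, which is what lets the fast path skip the convert_binary_elementwise chain that the else branch still builds for ITensor indices.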