Numpy changes for aten::index converter (#2396)
apbose authored Nov 8, 2023
1 parent 88f6812 commit 7029e91
Showing 3 changed files with 93 additions and 43 deletions.
3 changes: 1 addition & 2 deletions py/torch_tensorrt/dynamo/conversion/converter_utils.py
@@ -181,8 +181,7 @@ def cast_int_int_div_trt_tensor(


def broadcastable(
-    a: TRTTensor,
-    b: TRTTensor,
+    a: Union[TRTTensor, np.ndarray], b: Union[TRTTensor, np.ndarray]
) -> bool:
"Check if two tensors are broadcastable according to torch rules"
a_shape = tuple(a.shape)
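For reference, the torch broadcasting rule that `broadcastable` enforces can be sketched in a few lines. This is a hypothetical standalone re-implementation for illustration, not the library function:

```python
def broadcastable_sketch(a_shape, b_shape) -> bool:
    # Torch/numpy rule: right-align the two shapes; every aligned pair of
    # dimensions must be equal, or at least one of the pair must be 1.
    for x, y in zip(reversed(a_shape), reversed(b_shape)):
        if x != y and x != 1 and y != 1:
            return False
    return True

assert broadcastable_sketch((4, 1, 3), (2, 3))      # 1 broadcasts against 2
assert not broadcastable_sketch((4, 2, 3), (3, 3))  # 2 vs 3 is incompatible
```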
103 changes: 64 additions & 39 deletions py/torch_tensorrt/dynamo/conversion/impl/select.py
@@ -3,6 +3,7 @@

import numpy as np
import tensorrt as trt
+import torch
from torch.fx.node import Target
from torch_tensorrt.dynamo._SourceIR import SourceIR
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
@@ -80,23 +81,34 @@ def index(
source_ir: Optional[SourceIR],
name: str,
input: TRTTensor,
-    index: Union[TRTTensor, Sequence[TRTTensor]],
+    index: Sequence[Union[TRTTensor, np.ndarray, torch.Tensor]],
) -> TRTTensor:
adv_indx_indices = []
tensor_indices = []
# _LOGGER.debug(f"The index shape is {index.shape}")
# check if the input is dynamic
dynamic_shape = has_dynamic_shape(input.shape)

+    # is_numpy is a flag indicating whether all the indices are numpy arrays or torch Tensors.
+    # If any index is not, the flag is set to False.
+    _LOGGER.debug(
+        "Determining whether aten.index constant-index optimization can be invoked"
+    )
+    is_numpy = all(
+        isinstance(ind, (torch.Tensor, np.ndarray)) for ind in index if ind is not None
+    )
-    # here we need to check if all the indices are broadcastable
-    # if not, then we need to broadcast them
last_index = None
for i, ind in enumerate(index):
if ind is not None:
_LOGGER.debug(f"Shape of {i} index is {ind.shape}")
adv_indx_indices.append(i)
-            # torch.nn.parameter.Parameter => torch.Tensor
-            ind = get_trt_tensor(ctx, ind, name + f"_parameter_to_fp32_tensor_{i}")
+            # torch.nn.parameter.Parameter => numpy array
+            # numpy arrays are kept as numpy
+            # all other cases are converted to TRTTensor
+            if is_numpy:
+                ind = to_numpy(ind)
+            else:
+                ind = get_trt_tensor(ctx, ind, name + f"_parameter_to_fp32_tensor_{i}")
if last_index is not None:
assert broadcastable(
ind, last_index
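The `is_numpy` flag above is what selects between the new constant-folding path and the original TensorRT-layer path. A rough standalone illustration of that predicate (the `object()` placeholder standing in for a runtime `ITensor` is hypothetical):

```python
import numpy as np
import torch

def all_constant_indices(index) -> bool:
    # Mirrors the is_numpy check: every non-None index must be a constant
    # (torch.Tensor or np.ndarray) rather than a runtime TRT ITensor.
    return all(
        isinstance(ind, (torch.Tensor, np.ndarray)) for ind in index if ind is not None
    )

print(all_constant_indices([None, torch.tensor([0, 1])]))        # True
print(all_constant_indices([np.array([0]), torch.tensor([1])]))  # True
print(all_constant_indices([None, object()]))                    # False
```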
@@ -110,8 +122,9 @@ def index(
set_layer_name(identity_layer, target, name + "_index_identity", source_ir)
return identity_layer.get_output(0)
elif len(tensor_indices) == 1:
-        # This case works
-        indices_tensor = tensor_indices[0]
+        indices_tensor = get_trt_tensor(
+            ctx, tensor_indices[0], name + "_parameter_to_fp32_tensor"
+        )
index = adv_indx_indices[0]
_LOGGER.debug(f"The advanced index indices is {adv_indx_indices}")
gather_layer = ctx.net.add_gather(input, indices_tensor, index)
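The `len(tensor_indices) == 1` branch above lowers a single advanced index to one TensorRT gather. An eager-mode sketch of the computation that branch reproduces (illustrative, not the converter code):

```python
import torch

x = torch.arange(12).reshape(3, 4)
idx = torch.tensor([2, 0])

# One advanced index selects along a single axis, i.e. a gather on dim 0,
# which is what ctx.net.add_gather(input, indices_tensor, index) emits.
assert torch.equal(x[idx], torch.index_select(x, 0, idx))
```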
@@ -150,6 +163,7 @@ def index(
if i not in adv_indx_indices:
new_order.append(i)
_LOGGER.debug(f"The new transpose order is {new_order}")

transpose_layer.second_transpose = tuple(new_order)
set_layer_name(transpose_layer, target, name + "_index_transpose", source_ir)
transpose_tensor = transpose_layer.get_output(0)
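The transpose order built above moves every advanced-index dimension to the front while the remaining dimensions keep their relative order; a minimal sketch with assumed values:

```python
# Hypothetical setup: rank-4 input, advanced indices on dims 1 and 3.
rank = 4
adv_indx_indices = [1, 3]

new_order = adv_indx_indices + [i for i in range(rank) if i not in adv_indx_indices]
print(new_order)  # [1, 3, 0, 2] -- advanced dims first, the rest in order
```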
@@ -175,47 +189,58 @@ def index(

        concat_tensor = concat_tensor_layer.get_output(0)

        reshape_layer = ctx.net.add_shuffle(transpose_tensor)
        # check this
        reshape_layer.set_input(1, concat_tensor)
        flatten_tensor = reshape_layer.get_output(0)

        _LOGGER.debug(f"The flatten tensor shape is {flatten_tensor.shape}")

        # tensor index = \sum_{i=1}^m (ind_i * \prod_{j=i+1}^m (x_j)),
        # where ind_i is input indices[i] and x_j is the j-th dimension of input x.
-        multiplier = get_trt_tensor(
-            ctx,
-            dim_tensor_list[adv_indx_indices[adv_indx_count - 1]],
-            name + "_dim_last",
-        )
-        cum_adv_index = tensor_indices[adv_indx_count - 1]
-        for i in range(adv_indx_count - 2, -1, -1):
-            adv_index = convert_binary_elementwise(
-                ctx,
-                target,
-                source_ir,
-                name + f"_index_intermediate_{i}",
-                trt.ElementWiseOperation.PROD,
-                multiplier,
-                tensor_indices[i],
-            )
-            cum_adv_index = convert_binary_elementwise(
-                ctx,
-                target,
-                source_ir,
-                name + f"_index_sum_intermediate_{i}",
-                trt.ElementWiseOperation.SUM,
-                cum_adv_index,
-                adv_index,
-            )
-            multiplier = convert_binary_elementwise(
-                ctx,
-                target,
-                source_ir,
-                name + f"_index_intermediate_xj_{i}",
-                trt.ElementWiseOperation.PROD,
-                multiplier,
-                dim_tensor_list[adv_indx_indices[i]],
-            )
+        if is_numpy:
+            multiplier = input_shape[adv_indx_indices[adv_indx_count - 1]]
+            cum_adv_index = tensor_indices[adv_indx_count - 1]
+            for i in range(adv_indx_count - 2, -1, -1):
+                adv_index = multiplier * tensor_indices[i]
+                cum_adv_index = cum_adv_index + adv_index
+                multiplier = multiplier * input_shape[adv_indx_indices[i]]
+            cum_adv_index = get_trt_tensor(
+                ctx, cum_adv_index, name + "_index_sum_intermediate"
+            )
+        else:
+            multiplier = get_trt_tensor(
+                ctx,
+                dim_tensor_list[adv_indx_indices[adv_indx_count - 1]],
+                name + "_dim_last",
+            )
+            cum_adv_index = tensor_indices[adv_indx_count - 1]
+            for i in range(adv_indx_count - 2, -1, -1):
+                adv_index = convert_binary_elementwise(
+                    ctx,
+                    target,
+                    source_ir,
+                    name + f"_index_intermediate_{i}",
+                    trt.ElementWiseOperation.PROD,
+                    multiplier,
+                    tensor_indices[i],
+                )
+                cum_adv_index = convert_binary_elementwise(
+                    ctx,
+                    target,
+                    source_ir,
+                    name + f"_index_sum_intermediate_{i}",
+                    trt.ElementWiseOperation.SUM,
+                    cum_adv_index,
+                    adv_index,
+                )
+                multiplier = convert_binary_elementwise(
+                    ctx,
+                    target,
+                    source_ir,
+                    name + f"_index_intermediate_xj_{i}",
+                    trt.ElementWiseOperation.PROD,
+                    multiplier,
+                    dim_tensor_list[adv_indx_indices[i]],
+                )

        gather_layer_element = ctx.net.add_gather(flatten_tensor, cum_adv_index, 0)
        set_layer_name(
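The `tensor index` comment above encodes row-major flattening: each advanced index is scaled by the product of the sizes of the dimensions indexed after it, and the results are summed. A small numpy check of that arithmetic, mirroring what the constant-folded `is_numpy` branch computes (illustrative only):

```python
import numpy as np

x = np.arange(6).reshape(2, 3)  # input of shape (2, 3)
ind0 = np.array([0, 1])         # constant index for dim 0
ind1 = np.array([2, 0])         # constant index for dim 1

# flat = ind0 * x.shape[1] + ind1, i.e. the loop's
# cum_adv_index = cum_adv_index + multiplier * tensor_indices[i]
multiplier = x.shape[1]
cum_adv_index = ind1 + multiplier * ind0

assert np.array_equal(x.reshape(-1)[cum_adv_index], x[ind0, ind1])
```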
30 changes: 28 additions & 2 deletions tests/py/dynamo/conversion/test_index_aten.py
@@ -2,11 +2,10 @@

import torch
import torch.nn as nn
-from .harness import DispatchTestCase
from torch.testing._internal.common_utils import run_tests
from torch_tensorrt import Input

+from .harness import DispatchTestCase


class TestIndexConverter(DispatchTestCase):
def test_index_zero_two_dim(self):
@@ -27,6 +26,21 @@ def forward(self, x):
input,
)

+    def test_index_zero_two_dim_ITensor(self):
+        class TestModule(nn.Module):
+            def forward(self, x, index0):
+                indices = [None, index0]
+                out = torch.ops.aten.index.Tensor(x, indices)
+                return out
+
+        input = torch.randn(2, 2)
+        index0 = torch.randint(0, 1, (1, 1))
+        index0 = index0.to(torch.int32)
+        self.run_test(
+            TestModule(),
+            [input, index0],
+        )

def test_index_zero_index_three_dim(self):
class TestModule(nn.Module):
def __init__(self):
@@ -44,6 +58,18 @@ def forward(self, x):
input,
)

+    def test_index_zero_index_three_dim_ITensor(self):
+        class TestModule(nn.Module):
+            def forward(self, x, index0):
+                indices = [None, index0, None]
+                out = torch.ops.aten.index.Tensor(x, indices)
+                return out
+
+        input = torch.randn(2, 2, 2)
+        index0 = torch.randint(0, 1, (1, 1))
+        index0 = index0.to(torch.int32)
+        self.run_test(TestModule(), [input, index0])

def test_index_zero_index_one_index_two_three_dim(self):
class TestModule(nn.Module):
def __init__(self):
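The two new `_ITensor` tests feed the index as a graph input, so it reaches the converter as a runtime tensor and exercises the non-numpy branch; the cast to `torch.int32` is presumably to match TensorRT's int32 gather-index requirement. Note that `torch.randint(0, 1, (1, 1))` can only produce 0, so the tests always gather position 0. A hedged eager-mode reference for what the first test expects:

```python
import torch

x = torch.randn(2, 2)
index0 = torch.zeros(1, 1, dtype=torch.int64)  # what randint(0, 1, ...) yields

# indices = [None, index0] leaves dim 0 whole and advanced-indexes dim 1,
# giving the same result as x[:, index0] (shape (2, 1, 1)).
out = torch.ops.aten.index.Tensor(x, [None, index0])
assert torch.equal(out, x[:, index0])
```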
