Skip to content

Commit

Permalink
small fix: Index validator enable int64 (#2642)
Browse files Browse the repository at this point in the history
  • Loading branch information
gs-olive authored Feb 6, 2024
1 parent e38a7f3 commit ffbcc7a
Show file tree
Hide file tree
Showing 4 changed files with 9 additions and 10 deletions.
7 changes: 5 additions & 2 deletions examples/dynamo/torch_compile_advanced_usage.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def forward(self, x: torch.Tensor, y: torch.Tensor):
# For the default settings, we can simply call torch.compile
# with the backend "torch_tensorrt", and run the model on an
# input to cause compilation, as so:
optimized_model = torch.compile(model, backend="torch_tensorrt")
optimized_model = torch.compile(model, backend="torch_tensorrt", dynamic=False)
optimized_model(*sample_inputs)

# %%
Expand Down Expand Up @@ -81,7 +81,10 @@ def forward(self, x: torch.Tensor, y: torch.Tensor):

# Run the model on an input to cause compilation, as so:
optimized_model_custom = torch.compile(
model_half, backend="torch_tensorrt", options=backend_kwargs
model_half,
backend="torch_tensorrt",
options=backend_kwargs,
dynamic=False,
)
optimized_model_custom(*sample_inputs_half)

Expand Down
1 change: 1 addition & 0 deletions examples/dynamo/torch_compile_transformers_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
optimized_model = torch.compile(
model,
backend="torch_tensorrt",
dynamic=False,
options=compilation_kwargs,
)
optimized_model(*inputs)
Expand Down
2 changes: 1 addition & 1 deletion py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,7 +397,7 @@ def index_dtype_validator(node: Node) -> bool:
for ind in index:
if ind is not None:
val = ind.meta.get("val")
if val is not None and val.dtype != torch.int32:
if val is not None and val.dtype not in (torch.int32, torch.int64):
return False
return True

Expand Down
9 changes: 2 additions & 7 deletions tests/py/dynamo/conversion/test_index_aten.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
import operator

import torch
import torch.nn as nn
from .harness import DispatchTestCase
from torch.testing._internal.common_utils import run_tests
from torch_tensorrt import Input

from .harness import DispatchTestCase


class TestIndexConverter(DispatchTestCase):
Expand All @@ -15,7 +13,6 @@ def __init__(self):
super().__init__()

def forward(self, x):
index0 = torch.randint(0, 1, (1, 1))
indices = [None, self.index0]
out = torch.ops.aten.index.Tensor(x, indices)
return out
Expand Down Expand Up @@ -158,8 +155,6 @@ def __init__(self):
super().__init__()

def forward(self, x):
index0 = torch.tensor([0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7])
index1 = index0.unsqueeze(0).T.long()
indices = [None, None, self.index0, self.index1]
out = torch.ops.aten.index.Tensor(x, indices)
return out
Expand Down

0 comments on commit ffbcc7a

Please sign in to comment.