Skip to content

Commit

Permalink
fix: Address runtimes with 0D inputs (#2188)
Browse files Browse the repository at this point in the history
  • Loading branch information
gs-olive authored Aug 22, 2023
1 parent 56b8950 commit e58f319
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 7 deletions.
2 changes: 1 addition & 1 deletion core/runtime/execute_engine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
TORCHTRT_CHECK(
inputs[i].dtype() == expected_type,
"Expected input tensors to have type " << expected_type << ", found type " << inputs[i].dtype());
auto dims = core::util::toDimsPad(inputs[i].sizes(), 1);
auto dims = core::util::toDims(inputs[i].sizes());
auto shape = core::util::toVec(dims);
LOG_DEBUG("Input Name: " << name << " Shape: " << dims);
compiled_engine->exec_ctx->setInputShape(name.c_str(), dims);
Expand Down
8 changes: 4 additions & 4 deletions py/torch_tensorrt/_Input.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,18 +339,18 @@ def from_tensor(
A Input object.
"""
if not (
t.is_contiguous(memory_format=torch.contiguous_format)
disable_memory_format_check
or t.is_contiguous(memory_format=torch.contiguous_format)
or t.is_contiguous(memory_format=torch.channels_last)
or disable_memory_format_check
):
raise ValueError(
"Tensor does not have a supported memory format, supported formats are contiguous or channel_last"
)
frmt = (
torch.contiguous_format
if (
t.is_contiguous(memory_format=torch.contiguous_format)
or disable_memory_format_check
disable_memory_format_check
or t.is_contiguous(memory_format=torch.contiguous_format)
)
else torch.channels_last
)
Expand Down
42 changes: 40 additions & 2 deletions tests/py/dynamo/backend/test_specialized_models.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from utils import lower_graph_testing
from torch.testing._internal.common_utils import run_tests, TestCase
import torch
import torch_tensorrt
from torch.testing._internal.common_utils import TestCase, run_tests
from utils import lower_graph_testing


class TestFakeTensors(TestCase):
Expand Down Expand Up @@ -118,5 +118,43 @@ def forward(self, x):
torch._dynamo.reset()


class Test0DTensors(TestCase):
def test_0D_input(self):
class Tensor0DInput(torch.nn.Module):
def forward(self, x):
return x * 7

inputs = [
torch.tensor(
3,
)
.cuda()
.int(),
]

fx_graph = torch.fx.symbolic_trace(Tensor0DInput())

# Validate that the results between Torch and Torch-TRT are similar
optimized_model = torch_tensorrt.compile(
fx_graph,
"torch_compile",
inputs,
min_block_size=1,
pass_through_build_failures=True,
)
optimized_model_results = optimized_model(*inputs).detach().cpu()
torch_model_results = fx_graph(*inputs).detach().cpu()

max_diff = float(
torch.max(torch.abs(optimized_model_results - torch_model_results))
)
self.assertAlmostEqual(
max_diff,
0,
msg=f"0D-Tensor TRT outputs don't match with the original model.",
)
torch._dynamo.reset()


if __name__ == "__main__":
run_tests()

0 comments on commit e58f319

Please sign in to comment.