From 8fdaaf59c5e0d1347d05f15018861e79fe7012a8 Mon Sep 17 00:00:00 2001 From: apbose Date: Mon, 14 Aug 2023 08:45:09 -0700 Subject: [PATCH] Type mismatch for dynamo converter --- py/torch_tensorrt/dynamo/conversion/impl/condition/ops.py | 6 +++--- tests/py/dynamo/converters/test_where_aten.py | 1 + 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/condition/ops.py b/py/torch_tensorrt/dynamo/conversion/impl/condition/ops.py index 1b46f106b8..b81418490c 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/condition/ops.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/condition/ops.py @@ -65,7 +65,7 @@ def where( condition_val = condition_layer.get_output(0) else: assert condition.dtype == trt.bool, "mask dtype is not bool!" - if condition_shape != condition_dim: # TODO: What is this checking? + if len(condition_shape) != condition_dim: condition_val = expand( network, target, source_ir, f"{name}_expand", condition, output_shape ) @@ -73,7 +73,7 @@ def where( condition_val = condition if type(input) != TRTTensor: - if x_shape != input_dim: # TODO: What is this checking? + if x_shape != output_shape: # special case where 1 element in input if len(input.shape) == 0: input = input.unsqueeze(0) @@ -95,7 +95,7 @@ def where( y_val = get_trt_tensor(network, other, f"{name}_y") else: y_val = other - if y_shape != other_dim: # TODO: What is this checking? + if y_shape != output_shape: y_val = expand( network, target, source_ir, f"{name}_y_expand", y_val, output_shape ) diff --git a/tests/py/dynamo/converters/test_where_aten.py b/tests/py/dynamo/converters/test_where_aten.py index ddeb269ee9..f15477092d 100644 --- a/tests/py/dynamo/converters/test_where_aten.py +++ b/tests/py/dynamo/converters/test_where_aten.py @@ -12,6 +12,7 @@ class TestWhereConverter(DispatchTestCase): ("2d_broadcast_condition_xshape_yshape", (2, 2), (2, 1)), ("3d_condition_xshape_yshape", (2, 2, 1), (2, 2, 1)), ("2d_3d_condition_xshape_yshape", (2, 2), (1, 2, 2)), + ("3d_2d_condition_xshape_yshape", (1, 2, 2), (2, 2)), ] ) def test_(self, _, x_size, y_size):