Skip to content

Commit

Permalink
Merge pull request #2198 from pytorch/dynamo_converter_where_type_mis…
Browse files Browse the repository at this point in the history
…match

Type mismatch for dynamo aten::where converter
  • Loading branch information
narendasan authored Aug 21, 2023
2 parents 91fcea4 + 8fdaaf5 commit 148b3ba
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 3 deletions.
6 changes: 3 additions & 3 deletions py/torch_tensorrt/dynamo/conversion/impl/condition/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,15 +65,15 @@ def where(
condition_val = condition_layer.get_output(0)
else:
assert condition.dtype == trt.bool, "mask dtype is not bool!"
if condition_shape != condition_dim: # TODO: What is this checking?
if len(condition_shape) != condition_dim:
condition_val = expand(
network, target, source_ir, f"{name}_expand", condition, output_shape
)
else:
condition_val = condition

if type(input) != TRTTensor:
if x_shape != input_dim: # TODO: What is this checking?
if x_shape != output_shape:
# special case where 1 element in input
if len(input.shape) == 0:
input = input.unsqueeze(0)
Expand All @@ -95,7 +95,7 @@ def where(
y_val = get_trt_tensor(network, other, f"{name}_y")
else:
y_val = other
if y_shape != other_dim: # TODO: What is this checking?
if y_shape != output_shape:
y_val = expand(
network, target, source_ir, f"{name}_y_expand", y_val, output_shape
)
Expand Down
1 change: 1 addition & 0 deletions tests/py/dynamo/converters/test_where_aten.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ class TestWhereConverter(DispatchTestCase):
("2d_broadcast_condition_xshape_yshape", (2, 2), (2, 1)),
("3d_condition_xshape_yshape", (2, 2, 1), (2, 2, 1)),
("2d_3d_condition_xshape_yshape", (2, 2), (1, 2, 2)),
("3d_2d_condition_xshape_yshape", (1, 2, 2), (2, 2)),
]
)
def test_(self, _, x_size, y_size):
Expand Down

0 comments on commit 148b3ba

Please sign in to comment.