From 8803be82a62d48c14e8a34c86dfb97df1baca812 Mon Sep 17 00:00:00 2001 From: ganler Date: Sat, 12 Mar 2022 20:45:24 -0600 Subject: [PATCH 1/2] fix flatten --- python/tvm/relay/frontend/onnx.py | 4 ++-- tests/python/frontend/onnx/test_forward.py | 25 +++++++++++----------- 2 files changed, 14 insertions(+), 15 deletions(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 12673c5303b9..a751f23fe732 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -1138,7 +1138,7 @@ class Flatten(OnnxOpConverter): @classmethod def _impl_v1(cls, inputs, attr, params): axis = attr.get("axis", 1) - ishape = _op.shape_of(inputs[0]) + ishape = shape_of(inputs[0]) ndim = infer_shape(ishape)[0] if axis < 0: axis = axis + ndim @@ -1148,7 +1148,7 @@ def _impl_v1(cls, inputs, attr, params): else: pre_shape = _op.prod(_op.strided_slice(ishape, [0], [axis], [1]), keepdims=True) post_shape = _op.prod(_op.strided_slice(ishape, [axis], [ndim], [1]), keepdims=True) - newshape = _op.concatenate([pre_shape, post_shape], axis=0) + newshape = fold_constant(_op.concatenate([pre_shape, post_shape], axis=0)) out = _op.reshape(inputs[0], newshape) return out diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index b62509297300..699b87dc7043 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -576,22 +576,21 @@ def test_squeeze(target, dev): @tvm.testing.parametrize_targets def test_flatten(target, dev): + def verify_flatten(in_shape, axis, ref_shape): + flatten_node = helper.make_node("Flatten", ["in"], ["out"], axis=axis) - in_shape = (1, 3, 4, 4) - axis = 1 - ref_shape = (1, 48) - - flatten_node = helper.make_node("Flatten", ["in"], ["out"], axis=axis) + graph = helper.make_graph( + [flatten_node], + "flatten_test", + inputs=[helper.make_tensor_value_info("in", TensorProto.FLOAT, list(in_shape))], + outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, list(ref_shape))], + ) - graph = helper.make_graph( - [flatten_node], - "flatten_test", - inputs=[helper.make_tensor_value_info("in", TensorProto.FLOAT, list(in_shape))], - outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, list(ref_shape))], - ) + model = helper.make_model(graph, producer_name="flatten_test") + verify_with_ort(model, [in_shape], target=target, dev=dev) - model = helper.make_model(graph, producer_name="flatten_test") - verify_with_ort(model, [in_shape], target=target, dev=dev) + verify_flatten((1, 3, 4, 4), 1, (1, 48)) + verify_flatten((1), 1, (1, 1)) @tvm.testing.parametrize_targets From 3c2ec3c6b5a6eba19876d9e5597ab57d459d415c Mon Sep 17 00:00:00 2001 From: ganler Date: Sun, 13 Mar 2022 14:07:51 -0500 Subject: [PATCH 2/2] fix: python tuple to list --- tests/python/frontend/onnx/test_forward.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index 699b87dc7043..a4631e762f6f 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -590,7 +590,7 @@ def verify_flatten(in_shape, axis, ref_shape): verify_with_ort(model, [in_shape], target=target, dev=dev) verify_flatten((1, 3, 4, 4), 1, (1, 48)) - verify_flatten((1), 1, (1, 1)) + verify_flatten((1,), 1, (1, 1)) @tvm.testing.parametrize_targets