From 31a4267a1960754a91e2a5189a516598029b26b8 Mon Sep 17 00:00:00 2001 From: Jiawei Liu Date: Fri, 25 Mar 2022 14:51:11 -0500 Subject: [PATCH] [ONNX] fix reduce crash on scalar inputs (#10780) * fix reduce crash on scalar inputs * fix uncovered cases. * fix on different opset to pass ci --- python/tvm/relay/frontend/onnx.py | 18 ++++++++++++++++++ tests/python/frontend/onnx/test_forward.py | 2 ++ 2 files changed, 20 insertions(+) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index eea50081aa23..04fb17abbb19 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -1875,6 +1875,9 @@ def run_calculation(cls, inputs, axis, keepdims): @classmethod def _impl_v1(cls, inputs, attr, params): + if not infer_shape(inputs[0]): # promote scalar to 1-D tensor + inputs[0] = _op.expand_dims(inputs[0], axis=0) + if "axes" in attr: axis = attr.get("axes", 0) else: @@ -1885,6 +1888,9 @@ def _impl_v1(cls, inputs, attr, params): @classmethod def _impl_v12(cls, inputs, attr, params): + if not infer_shape(inputs[0]): # promote scalar to 1-D tensor + inputs[0] = _op.expand_dims(inputs[0], axis=0) + if len(inputs) == 2: if isinstance(inputs[1], _expr.Constant): # Get axis and unpack scalar @@ -1937,6 +1943,9 @@ class ReduceSumSquare(OnnxOpConverter): @classmethod def _impl_v1(cls, inputs, attr, params): + if not infer_shape(inputs[0]): # promote scalar to 1-D tensor + inputs[0] = _op.expand_dims(inputs[0], axis=0) + if "axes" in attr: axis = attr.get("axes", 0) else: @@ -1953,6 +1962,9 @@ class ReduceL1(OnnxOpConverter): @classmethod def _impl_v1(cls, inputs, attr, params): + if not infer_shape(inputs[0]): # promote scalar to 1-D tensor + inputs[0] = _op.expand_dims(inputs[0], axis=0) + if "axes" in attr: axis = attr.get("axes", 0) else: @@ -1969,6 +1981,9 @@ class ReduceL2(OnnxOpConverter): @classmethod def _impl_v1(cls, inputs, attr, params): + if not infer_shape(inputs[0]): # promote scalar to 1-D tensor + inputs[0] = _op.expand_dims(inputs[0], axis=0) + if "axes" in attr: axis = attr.get("axes", 0) else: @@ -1986,6 +2001,9 @@ class ReduceLogSum(OnnxOpConverter): @classmethod def _impl_v1(cls, inputs, attr, params): + if not infer_shape(inputs[0]): # promote scalar to 1-D tensor + inputs[0] = _op.expand_dims(inputs[0], axis=0) + if "axes" in attr: axis = attr.get("axes", 0) else: diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index a526da5ca445..91775d27b2de 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -1934,6 +1934,8 @@ def verify_reduce_func(func, data, axis, keepdims): ] for func in funcs: + verify_reduce_func(func, np.array(1.0).astype(np.float32), axis=None, keepdims=False) + for keepdims in [True, False]: verify_reduce_func( func, np.random.randn(3, 2, 2).astype(np.float32), axis=None, keepdims=keepdims